File size: 3,322 Bytes
575baf4
 
 
 
 
45bf11a
 
 
575baf4
a0c2076
 
575baf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45bf11a
1767d02
575baf4
1767d02
575baf4
 
 
 
 
 
 
 
 
 
 
 
 
25fc3a5
575baf4
 
1767d02
575baf4
 
 
cc1fd5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import functools
import os

import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel, GPT2Tokenizer, pipeline

# NOTE(review): these two names are ordinary module-level ints, not environment
# variables, so assigning them here does not by itself switch the Hugging Face
# libraries into offline mode. If offline mode is intended, it would have to be set
# via os.environ before `transformers` is imported — TODO confirm intent.
HF_DATASETS_OFFLINE = 1
TRANSFORMERS_OFFLINE = 1

# Every model and pixel tensor in this file is moved to this device.
device = "cpu"
# Hugging Face access token taken from the `auth_token` environment variable
# (None when the variable is unset).
auth_token = os.environ.get("auth_token")

# Decoding settings shared by every `model.generate(...)` call in this file.
max_length = 100
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
def predict_step(image_paths, model, tokenizer=None, feature_extractor=None):
    """Generate text for a batch of image files with a single ``generate`` call.

    Args:
        image_paths: iterable of paths passed to ``Image.open``.
        model: encoder-decoder model exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids. Must be supplied;
            earlier versions read a module-level ``tokenizer`` that this file never
            defines, so every call raised NameError.
        feature_extractor: callable turning the loaded images into ``pixel_values``.
            Must be supplied for the same reason.

    Returns:
        List of decoded, whitespace-stripped strings, one per row of generated ids.

    Raises:
        ValueError: if ``tokenizer`` or ``feature_extractor`` is missing.
    """
    if tokenizer is None or feature_extractor is None:
        raise ValueError("predict_step requires both a tokenizer and a feature_extractor")

    # NOTE(review): `Image` is not imported anywhere in this file (presumably
    # `from PIL import Image` is intended); until it is, non-empty input raises NameError.
    images = []
    for image_path in image_paths:
        # `with` closes the underlying file; convert() loads the pixels and returns a
        # standalone RGB image (a plain copy when the image is already RGB).
        with Image.open(image_path) as i_image:
            images.append(i_image.convert(mode="RGB"))

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]

def predict_step_image(dataset_images, feature_extractor, model, tokenizer=None):
    """Generate text for each image separately, one ``generate`` call per image.

    Args:
        dataset_images: iterable of images accepted by ``feature_extractor``.
        feature_extractor: callable producing ``pixel_values`` from an image.
        model: encoder-decoder model exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids. Must be supplied;
            earlier versions read a module-level ``tokenizer`` that this file never
            defines, so every non-empty call raised NameError.

    Returns:
        List with one entry per input image; each entry is the list of decoded,
        whitespace-stripped strings produced for that image.

    Raises:
        ValueError: if ``tokenizer`` is missing.
    """
    if tokenizer is None:
        raise ValueError("predict_step_image requires a tokenizer")

    results = []
    for image in dataset_images:
        pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        output_ids = model.generate(pixel_values, **gen_kwargs)

        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        results.append([pred.strip() for pred in preds])
    return results
 
def predict_step_single_image(image, tokenizer, feature_extractor, model):
    """Generate text for one image.

    Args:
        image: image accepted by ``feature_extractor``.
        tokenizer: tokenizer used to decode the generated ids.
        feature_extractor: callable producing ``pixel_values`` from the image.
        model: encoder-decoder model exposing ``generate``.

    Returns:
        A one-element list wrapping the list of decoded, whitespace-stripped
        strings, i.e. ``[[text, ...]]``.
    """
    encoded = feature_extractor(images=image, return_tensors="pt")
    generated = model.generate(encoded.pixel_values.to(device), **gen_kwargs)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return [[text.strip() for text in decoded]]

def predict_step_pixel(dataset_pixel_values, model, tokenizer=None):
    """Generate one string per precomputed pixel tensor.

    Args:
        dataset_pixel_values: iterable of tensors, each reshaped below to
            ``[1, 3, 224, 224]`` (so each must hold 3*224*224 elements).
        model: encoder-decoder model exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids. Must be supplied;
            earlier versions read a module-level ``tokenizer`` that this file never
            defines, so every non-empty call raised NameError.

    Returns:
        List of whitespace-stripped strings, one per input tensor.

    Raises:
        ValueError: if ``tokenizer`` is missing.
    """
    if tokenizer is None:
        raise ValueError("predict_step_pixel requires a tokenizer")

    results = []
    for pv in dataset_pixel_values:
        # Leading 1 makes this a batch of exactly one input.
        pixel_values = pv.reshape([1, 3, 224, 224]).to(device)
        output_ids = model.generate(pixel_values, **gen_kwargs)
        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # Batch size is 1, so keep only the first decoded string.
        results.append(preds[0].strip())
    return results

"""
    image methods
"""
def load_image2txt_model(image_model_name):
    model = VisionEncoderDecoderModel.from_pretrained(image_model_name)
    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-large-patch4-window7-224", use_auth_token=auth_token)
    
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", use_auth_token=auth_token)
    tokenizer.pad_token = tokenizer.eos_token

    model = model.to(device)
    return tokenizer, feature_extractor, model

@functools.lru_cache(maxsize=1)
def _cached_image2txt_model(image_model_name):
    """Load the model stack once per checkpoint path and reuse it across requests."""
    return load_image2txt_model(image_model_name)


def inference_image_pipe(image_input):
    """Gradio callback: generate the report text for one uploaded image.

    The model, feature extractor and tokenizer are loaded on the first request and
    cached; earlier versions reloaded them from disk/hub on every call.

    Args:
        image_input: value delivered by the ``gr.Image`` input, or ``None`` when the
            user submits without an image.

    Returns:
        The generated text as a single string (``""`` when no image was given).
    """
    if image_input is None:
        return ""

    image_model_name = "./checkpoint-21000"
    tokenizer, feature_extractor, image_model = _cached_image2txt_model(image_model_name)
    # predict_step_single_image returns [[text, ...]]; take the inner list of strings.
    preds = predict_step_single_image(image_input, tokenizer, feature_extractor, image_model)[0]
    # A "text" output renders str(value); join into a plain string rather than
    # showing a Python list repr such as "['...']".
    return "\n".join(preds)

# Web UI: one 256x256 image input, generated report as text output.
demo = gr.Interface(
    fn=inference_image_pipe,
    inputs=gr.Image(height=256, width=256),
    outputs="text",
    examples=["3212210S4492629-1.png", "3216497S4499373-1.png"],
)
with demo:
    # Extra heading added inside the interface's own context.
    gr.Markdown("POC XRaySwinGen - Automatic Medical Report")


if __name__ == "__main__":
    demo.launch()