import gradio as gr
import spaces
import markdown
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

# Prompt steering the model toward a structured, readable interpretation of the report
SYSTEM_INSTRUCTION = "You are a medical report interpreter. Your task is to analyze the provided medical reports, identify key medical terms, diagnoses, or abnormalities, and provide a clear interpretation. Based on your analysis, generate a detailed summary that includes an explanation of the findings, recommended actions, and any additional insights for the patient or healthcare provider. Ensure your output is structured and easily understandable for both professionals and non-professionals."


model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Load the 11B vision-language model in bfloat16; device_map="auto" places it on the available GPU(s)
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)


def extract_assistant_reply(input_string):
    # The assistant's reply follows this header tag in the decoded output
    start_tag = "<|start_header_id|>assistant<|end_header_id|>"
    start_index = input_string.find(start_tag)
    if start_index == -1:
        return "Assistant's reply not found."
    start_index += len(start_tag)
    # Take everything after the tag and drop the trailing end-of-turn token,
    # which processor.decode keeps because special tokens are not skipped
    assistant_reply = input_string[start_index:].strip()
    if assistant_reply.endswith("<|eot_id|>"):
        assistant_reply = assistant_reply[: -len("<|eot_id|>")].strip()
    return assistant_reply
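
# Example with a hypothetical decoded string (Llama 3 chat-template format):
#   extract_assistant_reply("...<|start_header_id|>assistant<|end_header_id|>\n\nHello<|eot_id|>")
#   returns "Hello"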



# On Hugging Face Spaces, @spaces.GPU allocates ZeroGPU hardware for the duration of each call
@spaces.GPU
def med_interpreter(image):
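    """Interpret an uploaded medical-report image and return the result as HTML."""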
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": SYSTEM_INSTRUCTION}
        ]}
    ]
    # Render the prompt with the chat template, then tokenize image and text together.
    # add_special_tokens=False because the template already inserts the BOS token.
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Generate the interpretation and decode it, keeping special tokens
    # so extract_assistant_reply can locate the assistant header
    output = model.generate(**inputs, max_new_tokens=4000)
    decoded = processor.decode(output[0])

    # Keep only the assistant's turn and render its Markdown as HTML
    markdown_text = extract_assistant_reply(decoded)
    html_output = markdown.markdown(markdown_text)
    return html_output

# Gradio UI
interface = gr.Interface(
    fn=med_interpreter,
    inputs=gr.Image(type="pil", label="Upload an image of the medical report"),
    outputs=gr.HTML(),
    title="Medical Report Insights",
    description="Upload an image of your medical report to get an interpretation of its findings"
)

# Launch the UI
interface.launch()
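
# Note (assumption, not verified in this Space): outside Hugging Face Spaces the
# spaces.GPU decorator is a no-op, so the script can also run locally on a GPU
# with enough memory for the 11B model (roughly 22 GB of weights in bfloat16).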