prithivMLmods committed
Commit 91cda81 · verified · 1 Parent(s): 5d63d59

Update app.py

Files changed (1)
  1. app.py +35 -52
app.py CHANGED
@@ -14,31 +14,28 @@ MODEL_OPTIONS = {
     "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
 }
 
-# Global variables for model and processor
-model = None
-processor = None
-
-# Function to load the selected model
-def load_model(model_name):
-    global model, processor
-    model_id = MODEL_OPTIONS[model_name]
-    print(f"Loading model: {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_id,
-        trust_remote_code=True,
-        torch_dtype=torch.float16
-    ).to("cuda").eval()
-    print(f"Model {model_id} loaded successfully!")
-    return f"Model {model_name} loaded!"
+# Default model setup
+current_model_id = MODEL_OPTIONS["Latex OCR"]
+processor = AutoProcessor.from_pretrained(current_model_id, trust_remote_code=True)
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    current_model_id,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
 
 @spaces.GPU
-def model_inference(input_dict, history, model_choice):
-    global model, processor
+def model_inference(input_dict, history, model_id):
+    global model, processor, current_model_id  # current_model_id is reassigned below, so it must be global too
 
-    # Load the selected model if not already loaded
-    if model is None or processor is None:
-        load_model(model_choice)
+    # Reload the model and processor if the model selection changes
+    if model_id != current_model_id:
+        current_model_id = model_id
+        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            torch_dtype=torch.float16
+        ).to("cuda").eval()
 
     text = input_dict["text"]
     files = input_dict["files"]
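Note on the hunk above: the reload branch constructs the new fp16 checkpoint while the previous one is still bound to `model`, so two sets of weights briefly coexist on the GPU during the swap and can exhaust VRAM on smaller cards. A minimal sketch of releasing the old checkpoint first; `release_gpu_model` is a hypothetical helper, not part of this commit:

```python
# Sketch only (not in this commit): before loading a new checkpoint inside
# model_inference, drop the old one so only one set of fp16 weights sits in VRAM.
import gc

import torch

def release_gpu_model():
    """Hypothetical helper: free the currently loaded global model/processor."""
    global model, processor
    model = None                 # drop the last references to the old weights
    processor = None
    gc.collect()                 # reclaim the Python-side wrapper objects
    torch.cuda.empty_cache()     # return the freed blocks to the CUDA allocator
```

The reload branch would call `release_gpu_model()` before its two `from_pretrained` calls; without it, the old and new models overlap on the GPU for the duration of the load.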
@@ -107,35 +104,21 @@ examples = [
     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
 ]
 
-# Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# **Qwen2.5-VL-3B-Instruct**")
-
-    # Model selection dropdown
-    model_choice = gr.Dropdown(
-        label="Model Selection",
-        choices=list(MODEL_OPTIONS.keys()),
-        value="Latex OCR"
-    )
-
-    # Load model button
-    load_model_btn = gr.Button("Load Model")
-    load_model_output = gr.Textbox(label="Model Load Status")
-
-    # Chat interface
-    chat_interface = gr.ChatInterface(
-        fn=model_inference,
-        description="Interact with the selected Qwen2-VL model.",
-        examples=examples,
-        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
-        stop_btn="Stop Generation",
-        multimodal=True,
-        cache_examples=False,
-        additional_inputs=[model_choice]  # Pass model_choice as an additional input
-    )
-
-    # Link the load model button to the load_model function
-    load_model_btn.click(load_model, inputs=model_choice, outputs=load_model_output)
-
-# Launch the demo
+# Gradio components
+model_choice = gr.Dropdown(
+    label="Model Selection",
+    choices=list(MODEL_OPTIONS.keys()),
+    value="Latex OCR"
+)
+
+demo = gr.ChatInterface(
+    fn=lambda inputs, history: model_inference(inputs, history, MODEL_OPTIONS[model_choice.value]),
+    description="# **Qwen2.5-VL-3B-Instruct**",
+    examples=examples,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    stop_btn="Stop Generation",
+    multimodal=True,
+    cache_examples=False,
+)
+
 demo.launch(debug=True)
 
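A separate caveat on the wiring hunk, stated as an assumption about Gradio's semantics: `model_choice.value` inside the lambda reads the component's initial value ("Latex OCR"), and the dropdown is never placed in a rendered layout, so the selection cannot change at runtime and every chat turn uses the default checkpoint. The pre-commit revision's `additional_inputs` mechanism both renders the dropdown (in an accordion below the chat) and passes its current value into the handler on each call. A sketch of that wiring against the new `model_inference` signature; `model_inference_by_name` is a hypothetical wrapper and assumes `model_inference` returns its reply rather than streaming it (a generator would need `yield from` instead):

```python
# Sketch only: thread the live dropdown value through additional_inputs,
# as the pre-commit revision did.
def model_inference_by_name(input_dict, history, model_name):
    # Map the dropdown label (a MODEL_OPTIONS key) to its checkpoint id.
    return model_inference(input_dict, history, MODEL_OPTIONS[model_name])

demo = gr.ChatInterface(
    fn=model_inference_by_name,
    description="# **Qwen2.5-VL-3B-Instruct**",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[model_choice],
)
```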