CYF200127 commited on
Commit
f1f2574
·
verified ·
1 Parent(s): 306ff4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -26
app.py CHANGED
@@ -5,6 +5,10 @@ from rxnim import RXNIM
5
  from getReaction import generate_combined_image
6
  import torch
7
  from rxn.reaction import Reaction
 
 
 
 
8
 
9
  PROMPT_DIR = "prompts/"
10
  ckpt_path = "./rxn/model/model.ckpt"
@@ -15,6 +19,7 @@ PROMPT_NAMES = {
15
  "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
16
  }
17
  example_diagram = "examples/exp.png"
 
18
 
19
  def list_prompt_files_with_names():
20
  """
@@ -36,6 +41,7 @@ def parse_reactions(output_json):
36
  reactions_data = json.loads(output_json) # 转换 JSON 字符串为字典
37
  reactions_list = reactions_data.get("reactions", [])
38
  detailed_output = []
 
39
 
40
  for reaction in reactions_list:
41
  reaction_id = reaction.get("reaction_id", "Unknown ID")
@@ -50,6 +56,7 @@ def parse_reactions(output_json):
50
  ]
51
  products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
52
  products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
 
53
 
54
  # 构造反应的完整字符串,定制字体颜色
55
  full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {', '.join(conditions_1)}"
@@ -64,7 +71,12 @@ def parse_reactions(output_json):
64
  reaction_output += "<br>"
65
  detailed_output.append(reaction_output)
66
 
67
- return detailed_output
 
 
 
 
 
68
 
69
  def process_chem_image(image, selected_task):
70
  chem_mllm = RXNIM()
@@ -78,11 +90,12 @@ def process_chem_image(image, selected_task):
78
  rxnim_result = chem_mllm.process(image_path, prompt_path)
79
 
80
  # 将 JSON 结果解析为结构化输出
81
- detailed_reactions = parse_reactions(rxnim_result)
82
 
83
  # 调用 RxnScribe 模型处理并生成整合图像
84
  predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
85
  combined_image_path = generate_combined_image(predictions, image_path)
 
86
 
87
  json_file_path = "output.json"
88
  with open(json_file_path, "w") as json_file:
@@ -90,7 +103,7 @@ def process_chem_image(image, selected_task):
90
 
91
 
92
  # 返回详细反应和整合图像
93
- return "\n\n".join(detailed_reactions), combined_image_path, example_diagram, json_file_path
94
 
95
 
96
  # 获取 prompts 和友好名字
@@ -106,26 +119,111 @@ examples = [
106
  ]
107
 
108
  # 定义 Gradio 界面
109
- demo = gr.Interface(
110
- fn=process_chem_image,
111
- inputs=[
112
- gr.Image(type="pil", label="Upload Reaction Image"),
113
- gr.Radio(
114
- choices=list(prompts_with_names.keys()), # 显示任务名字
115
- label="Select a predefined task",
116
- ),
117
- ],
118
- outputs=[
119
- gr.HTML(label="Reaction outputs"),
120
- gr.Image(label="Visualization"), # 显示整合图像
121
- gr.Image(value=example_diagram, label="Schematic Diagram"),
122
- gr.File(label="Download JSON File"),
123
-
124
- ],
125
- title="Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model",
126
- description="Upload a reaction image and select a predefined task prompt.",
127
- examples=examples, # 使用嵌套列表作为示例
128
- examples_per_page=20,
129
- )
130
-
131
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from getReaction import generate_combined_image
6
  import torch
7
  from rxn.reaction import Reaction
8
+ from rdkit import Chem
9
+ from rdkit.Chem import rdChemReactions
10
+ from rdkit.Chem import Draw
11
+
12
 
13
  PROMPT_DIR = "prompts/"
14
  ckpt_path = "./rxn/model/model.ckpt"
 
19
  "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
20
  }
21
  example_diagram = "examples/exp.png"
22
+ rdkit_image = "examples/image.webp"
23
 
24
  def list_prompt_files_with_names():
25
  """
 
41
  reactions_data = json.loads(output_json) # 转换 JSON 字符串为字典
42
  reactions_list = reactions_data.get("reactions", [])
43
  detailed_output = []
44
+ smiles_output = []
45
 
46
  for reaction in reactions_list:
47
  reaction_id = reaction.get("reaction_id", "Unknown ID")
 
56
  ]
57
  products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
58
  products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
59
+ products_2 = [r.get("smiles", "Unknown") for r in reaction.get("products", [])]
60
 
61
  # 构造反应的完整字符串,定制字体颜色
62
  full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {', '.join(conditions_1)}"
 
71
  reaction_output += "<br>"
72
  detailed_output.append(reaction_output)
73
 
74
+ reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products_2)}"
75
+ smiles_output.append(reaction_smiles)
76
+
77
+
78
+
79
+ return detailed_output, smiles_output
80
 
81
  def process_chem_image(image, selected_task):
82
  chem_mllm = RXNIM()
 
90
  rxnim_result = chem_mllm.process(image_path, prompt_path)
91
 
92
  # 将 JSON 结果解析为结构化输出
93
+ detailed_reactions, smiles_output = parse_reactions(rxnim_result)
94
 
95
  # 调用 RxnScribe 模型处理并生成整合图像
96
  predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
97
  combined_image_path = generate_combined_image(predictions, image_path)
98
+ #combined_image_path = model.draw_predictions(predictions, image_path)
99
 
100
  json_file_path = "output.json"
101
  with open(json_file_path, "w") as json_file:
 
103
 
104
 
105
  # 返回详细反应和整合图像
106
+ return "\n\n".join(detailed_reactions), smiles_output, combined_image_path, example_diagram, json_file_path
107
 
108
 
109
  # 获取 prompts 和友好名字
 
119
  ]
120
 
121
  # 定义 Gradio 界面
122
+ with gr.Blocks() as demo:
123
+ gr.Markdown(
124
+ """
125
+
126
+ <center> <h1>Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model<h1></center>
127
+
128
+ Upload a reaction image and select a predefined task prompt.
129
+ """)
130
+
131
+
132
+
133
+ # 上半部分,输入区域
134
+ with gr.Row(equal_height=False):
135
+ with gr.Column(scale=1): # 左侧列
136
+ image_input = gr.Image(type="pil", label="Upload Reaction Image")
137
+ task_radio = gr.Radio(
138
+ choices=list(prompts_with_names.keys()),
139
+ label="Select a predefined task",
140
+ )
141
+ with gr.Row(): # Clear 和 Submit 按钮放在同一行
142
+ clear_button = gr.Button("Clear")
143
+ process_button = gr.Button("Run", elem_id="submit-btn")
144
+
145
+ gr.Markdown("### Reaction Imge Parsing Output")
146
+ reaction_output = gr.HTML(label="Reaction outputs")
147
+
148
+
149
+ with gr.Column(scale=1):
150
+
151
+ gr.Markdown("### Reaction Extraction Output")
152
+ visualization_output = gr.Image(label="Visualization Output")
153
+ schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")
154
+
155
+
156
+ with gr.Column(scale=1):
157
+ gr.Markdown("### Machine-readable Data Output")
158
+ smiles_output = gr.Textbox(
159
+ label="Reaction SMILES",
160
+ show_copy_button=True,
161
+ interactive=False,
162
+ visible=False,
163
+ )
164
+
165
+
166
+ # 下半部分,图像和 JSON 输出
167
+ @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑
168
+ def show_split(inputs): # 定义处理和展示分割文本的函数
169
+ if not inputs or isinstance(inputs, str) and inputs.strip() == "": # 检查输入文本是否为空
170
+ return gr.Textbox(label= f"Reaction SMILES"), gr.Image(value=rdkit_image, label= "RDKit Image generated from Reaction SMILES")
171
+ else:
172
+ # 假设输入是逗号分隔的 SMILES 字符串
173
+ smiles_list = inputs.split(",")
174
+ smiles_list = [item.strip("[]' ") for item in smiles_list]
175
+ components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件
176
+ for i, smiles in enumerate(smiles_list):
177
+ smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
178
+ reaction = rdChemReactions.ReactionFromSmarts(smiles)
179
+ if reaction:
180
+ img = Draw.ReactionToImage(reaction)
181
+ components.append(gr.Textbox(value=smiles,label= f"Reaction {i + 1} SMILES", show_copy_button=True, interactive=False))
182
+ components.append(gr.Image(value=img,label= f"Reaction {i + 1} RDKit Image"))
183
+ return components # 返回包含所有 SMILES Textbox 组件的列表
184
+
185
+ download_json = gr.File(label="Download JSON File",)
186
+
187
+
188
+
189
+
190
+ # 示例部分
191
+ gr.Examples(
192
+ examples=examples,
193
+ inputs=[image_input, task_radio],
194
+ outputs=[reaction_output, smiles_output, visualization_output],
195
+ )
196
+
197
+ # 绑定功能
198
+ clear_button.click(
199
+ lambda: (None, None, None, None, None),
200
+ inputs=[],
201
+ outputs=[
202
+ image_input,
203
+ task_radio,
204
+ reaction_output,
205
+ smiles_output,
206
+ visualization_output,
207
+ ],
208
+ )
209
+
210
+ process_button.click(
211
+ process_chem_image,
212
+ inputs=[image_input, task_radio],
213
+ outputs=[
214
+ reaction_output,
215
+ smiles_output,
216
+ visualization_output,
217
+ schematic_diagram,
218
+ download_json,
219
+ ],
220
+ )
221
+
222
+ demo.css = """
223
+ #submit-btn {
224
+ background-color: #FF914D;
225
+ color: white;
226
+ font-weight: bold;
227
+ }
228
+ """
229
+ demo.launch()