"""Gradio demo: parse chemical reaction images into structured reaction data.

The app sends an uploaded reaction image through RXNIM (prompt-driven parsing
that returns JSON) and through the ``Reaction`` model (whose predictions are
rendered into a combined visualization image), then shows the formatted
reactions, per-reaction RDKit drawings, and a downloadable JSON file.
"""

import html
import json
import os

import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdChemReactions

from getReaction import generate_combined_image
from rxn.reaction import Reaction
from rxnim import RXNIM

PROMPT_DIR = "prompts/"
ckpt_path = "./rxn/model/model.ckpt"
model = Reaction(ckpt_path, device=torch.device('cpu'))

# Maps prompt file names to the friendly names shown in the UI.
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}
example_diagram = "examples/exp.png"
rdkit_image = "examples/image.webp"

# Text shown at the top of the page (rendered by gr.Markdown).
HEADER_MARKDOWN = """

Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model

Upload a reaction image and select a predefined task prompt. """

CUSTOM_CSS = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""


def list_prompt_files_with_names():
    """Return a ``{friendly_name: filename}`` mapping for the prompt files.

    Every ``.txt`` file in ``PROMPT_DIR`` is listed. Files present in
    ``PROMPT_NAMES`` use their predefined name; all others get a default
    name of the form ``"Task: <file stem>"``. Files are visited in sorted
    order so the resulting choices are stable across runs.
    """
    prompt_files = {}
    for filename in sorted(os.listdir(PROMPT_DIR)):
        if not filename.endswith(".txt"):
            continue
        default_name = f"Task: {os.path.splitext(filename)[0]}"
        prompt_files[PROMPT_NAMES.get(filename, default_name)] = filename
    return prompt_files


def _value_or_unknown(entry, *keys):
    """Return the first non-empty value of ``entry`` among ``keys`` as a string.

    Falls back to ``"Unknown"`` when every key is missing, ``None`` or empty,
    so the result can always be passed to ``str.join``.
    """
    for key in keys:
        value = entry.get(key)
        if value not in (None, ""):
            return str(value)
    return "Unknown"


def parse_reactions(output_json):
    """Parse reaction JSON into display text and reaction SMILES.

    Args:
        output_json: JSON string holding an object with a ``"reactions"``
            list. Each reaction may have ``reaction_id``, ``reactants``,
            ``conditions`` and ``products`` entries.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of equal-length lists.
        ``detailed_output`` holds one HTML snippet per reaction (values are
        HTML-escaped and lines are separated by ``<br>`` because the text is
        shown in a ``gr.HTML`` component). ``smiles_output`` holds one
        ``reactants>>products`` reaction SMILES per reaction.

    Raises:
        json.JSONDecodeError: If ``output_json`` is not valid JSON.
    """
    reactions_list = json.loads(output_json).get("reactions", [])
    detailed_output = []
    smiles_output = []

    for reaction in reactions_list:
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = [
            _value_or_unknown(item, "smiles")
            for item in reaction.get("reactants", [])
        ]
        products = [
            _value_or_unknown(item, "smiles")
            for item in reaction.get("products", [])
        ]
        # A condition is shown as "<smiles or text>[<role>]".
        conditions = [
            f"{_value_or_unknown(item, 'smiles', 'text')}"
            f"[{_value_or_unknown(item, 'role')}]"
            for item in reaction.get("conditions", [])
        ]

        reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products)}"
        full_reaction = f"{reaction_smiles} | {', '.join(conditions)}"

        text_lines = [
            f"Reaction: {reaction_id}",
            f" Reactants: {', '.join(reactants)}",
            f" Conditions: {', '.join(conditions)}",
            f" Products: {', '.join(products)}",
            f" Full Reaction: {full_reaction}",
            "",  # blank line separating consecutive reactions
        ]
        # Escape model-derived text before it is embedded in HTML.
        detailed_output.append(
            "".join(f"{html.escape(line, quote=False)}<br>" for line in text_lines)
        )
        smiles_output.append(reaction_smiles)

    return detailed_output, smiles_output


def _decode_smiles_payload(payload):
    """Turn the value of the hidden SMILES textbox into a list of strings.

    The payload is normally the JSON array produced by
    ``process_chem_image``. A plain comma-separated string is accepted as a
    fallback. Empty or missing payloads give an empty list.
    """
    if not payload:
        return []
    if isinstance(payload, (list, tuple)):
        items = payload
    else:
        text = str(payload).strip()
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            decoded = None
        items = decoded if isinstance(decoded, list) else text.split(",")
    cleaned = (str(item).strip() for item in items)
    return [item for item in cleaned if item]


def process_chem_image(image, selected_task):
    """Run both models on an uploaded image and collect the UI outputs.

    Args:
        image: The uploaded PIL image (``None`` when nothing was uploaded).
        selected_task: Friendly prompt name chosen in the radio group.

    Returns:
        A 5-tuple: HTML text describing the reactions, a JSON array string
        of reaction SMILES, the path of the combined visualization image,
        the schematic diagram path, and the path of the written JSON file.

    Raises:
        gr.Error: If no image or task was provided, or if the RXNIM result
            is not valid JSON.
    """
    if image is None:
        raise gr.Error("Please upload a reaction image first.")
    if selected_task not in prompts_with_names:
        raise gr.Error("Please select a predefined task.")

    chem_mllm = RXNIM()

    # Convert the friendly name back to the actual prompt file.
    prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])
    image_path = "temp_image.png"
    image.save(image_path)

    # Run RXNIM and turn its JSON result into structured output.
    rxnim_result = chem_mllm.process(image_path, prompt_path)
    try:
        detailed_reactions, smiles_list = parse_reactions(rxnim_result)
        parsed_result = json.loads(rxnim_result)
    except json.JSONDecodeError as err:
        raise gr.Error("The model returned malformed JSON; please try again.") from err

    # Run the Reaction model and build the combined visualization image.
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    combined_image_path = generate_combined_image(predictions, image_path)

    json_file_path = "output.json"
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(parsed_result, json_file, indent=4)

    # The SMILES list is sent as JSON so that the render function can decode
    # it without damaging bracket atoms such as "[Na+]".
    return (
        "\n\n".join(detailed_reactions),
        json.dumps(smiles_list),
        combined_image_path,
        example_diagram,
        json_file_path,
    )


prompts_with_names = list_prompt_files_with_names()

examples = [
    ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction2.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction3.png", "Reaction Image Parsing Workflow"],
    ["examples/reaction4.png", "Reaction Image Parsing Workflow"],
]

# Gradio interface definition.
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MARKDOWN)

    # Input area and outputs, laid out in three columns.
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            # Clear and Run buttons share one row.
            with gr.Row():
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")
            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        with gr.Column(scale=1):
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(
                value=example_diagram, label="Schematic Diagram"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Data Output")
            # Hidden carrier for the reaction SMILES; it drives show_split.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render one SMILES textbox and one RDKit image per reaction.

                With no SMILES available, placeholder components are shown.
                """
                smiles_list = _decode_smiles_payload(inputs)
                if not smiles_list:
                    gr.Textbox(label="SMILES of Reaction i")
                    gr.Image(value=rdkit_image, label="RDKit Image of Reaction i")
                    return

                for index, smiles in enumerate(smiles_list, start=1):
                    gr.Textbox(
                        value=smiles,
                        label=f"SMILES of Reaction {index} ",
                        show_copy_button=True,
                        interactive=False,
                    )
                    try:
                        # useSmiles=True: the input is reaction SMILES, not SMARTS.
                        reaction = rdChemReactions.ReactionFromSmarts(
                            smiles, useSmiles=True
                        )
                        img = Draw.ReactionToImage(reaction) if reaction else None
                    except (ValueError, RuntimeError):
                        img = None
                    if img is None:
                        gr.Markdown(
                            f"Reaction {index} could not be drawn from its SMILES."
                        )
                        continue
                    gr.Image(value=img, label=f"RDKit Image of Reaction {index} ")

            download_json = gr.File(label="Download JSON File")

    # Example inputs.
    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    # Event wiring.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
            download_json,
        ],
    )

    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[
            reaction_output,
            smiles_output,
            visualization_output,
            schematic_diagram,
            download_json,
        ],
    )

demo.css = CUSTOM_CSS

demo.launch()