import logging

import gradio as gr
from transformers import pipeline

logger = logging.getLogger(__name__)

# Hugging Face model used for ICD-9 classification, and how many candidate
# codes to request from it per query.
MODEL_NAME = "ashishkgpian/biobert_icd9_classifier_ehr"
TOP_K = 5

# Load the model once at import time so every request reuses the same pipeline.
classifier = pipeline("text-classification", model=MODEL_NAME)


def classify_symptoms(text: str):
    """Suggest ICD-9 codes for a free-text clinical symptom description.

    Args:
        text: Symptom description entered by the user.

    Returns:
        A list of ``{"ICD9 Code": <label>, "Confidence": "<pct>"}`` dicts, one
        per prediction returned by the pipeline (requested with
        ``top_k=TOP_K``). On empty input or a classification failure, a dict
        with a single ``"error"`` key is returned instead. A dict is used
        rather than a bare string because ``gr.JSON`` parses string return
        values as JSON text, so a plain error message would itself fail to
        render.
    """
    if not text or not text.strip():
        return {"error": "Please enter a clinical symptom description."}

    try:
        results = classifier(text, top_k=TOP_K)
    except Exception as e:  # UI boundary: report the failure to the user, but keep the traceback in logs.
        logger.exception("ICD9 classification failed")
        return {"error": f"Error processing classification: {e}"}

    return [
        {
            "ICD9 Code": result["label"],
            "Confidence": f"{result['score']:.2%}",
        }
        for result in results
    ]


# Enhanced CSS for a more professional medical look.
# The header is targeted via an explicit elem_id ("app-header") rather than
# Gradio's auto-generated "#component-N" ids, which are not a stable API.
custom_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
    padding: 2rem !important;
    background-color: #f0f4f7 !important;
}
#app-header {
    text-align: center;
    padding: 1rem;
    margin-bottom: 2rem;
    background: #ffffff;
    border-radius: 10px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
#app-header h1 {
    color: #2c3e50;
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
}
#app-header h3 {
    color: #34495e;
    font-size: 1.2rem;
    font-weight: normal;
}
.input-container {
    background: white !important;
    padding: 2rem !important;
    border-radius: 10px !important;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
    margin-bottom: 1.5rem !important;
}
.textbox textarea {
    border: 2px solid #3498db !important;
    border-radius: 8px !important;
    padding: 1rem !important;
    font-size: 1.1rem !important;
    min-height: 120px !important;
}
.output-container {
    background: white !important;
    padding: 2rem !important;
    border-radius: 10px !important;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
.examples-container {
    background: white !important;
    padding: 1.5rem !important;
    border-radius: 10px !important;
    margin-top: 1rem !important;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
.example-text {
    font-size: 0.9rem !important;
    color: #2c3e50 !important;
    padding: 0.5rem !important;
    border-radius: 4px !important;
    background: #f8f9fa !important;
    margin: 0.5rem 0 !important;
}
.footer {
    text-align: center;
    margin-top: 2rem;
    padding: 1rem;
    background: white;
    border-radius: 10px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
.label {
    font-size: 1.2rem !important;
    color: #2c3e50 !important;
    font-weight: 600 !important;
    margin-bottom: 0.5rem !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Row(elem_id="app-header"):
        with gr.Column():
            gr.Markdown(
                "# 🏥 MedAI: Clinical Symptom ICD9 Classifier\n"
                "### Advanced AI-Powered Diagnostic Code Assistant"
            )

    with gr.Row():
        with gr.Column(elem_classes=["input-container"]):
            input_text = gr.Textbox(
                label="Clinical Symptom Description",
                placeholder="Enter detailed patient symptoms and clinical observations...",
                lines=5,
                elem_classes=["textbox"],
            )
            # A multi-line Textbox only fires `submit` on Shift+Enter, so give
            # users an explicit, discoverable way to run the classifier.
            classify_button = gr.Button("Classify", variant="primary")

    with gr.Row():
        output = gr.JSON(
            label="Suggested ICD9 Diagnostic Codes",
            elem_classes=["output-container"],
        )

    with gr.Row():
        with gr.Column(elem_classes=["examples-container"]):
            gr.Examples(
                examples=[
                    ["45-year-old male experiencing severe chest pain, radiating to left arm, with shortness of breath and excessive sweating"],
                    ["Persistent headache for 2 weeks, accompanied by dizziness and occasional blurred vision"],
                    ["Diabetic patient reporting frequent urination, increased thirst, and unexplained weight loss"],
                    ["Elderly patient with chronic knee pain, reduced mobility, and signs of inflammation"],
                ],
                inputs=input_text,
                label="Example Clinical Cases",
                elem_classes=["example-text"],
            )

    # Both Shift+Enter in the textbox and the button trigger classification.
    input_text.submit(fn=classify_symptoms, inputs=input_text, outputs=output)
    classify_button.click(fn=classify_symptoms, inputs=input_text, outputs=output)

    with gr.Row():
        gr.Markdown(
            "⚠️ For research and educational use only. Suggested ICD9 codes are "
            "model predictions and must be reviewed by a qualified clinician or "
            "medical coder before any clinical or billing use.",
            elem_classes=["footer"],
        )

if __name__ == "__main__":
    demo.launch()