"""Medi Mind: a Gradio chat UI that maps free-text symptoms to a diagnosis label.

The classifier is the Hugging Face checkpoint ``Zabihin/Symptom_to_Diagnosis``
loaded through TensorFlow. Every reply tells the user to consult a doctor.
"""

import logging

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
import gradio as gr

logger = logging.getLogger(__name__)

# Load the tokenizer and model once at import time so every request reuses them.
model_name = "Zabihin/Symptom_to_Diagnosis"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)

# Class index -> diagnosis name. Built once instead of on every request.
# NOTE(review): this table is hard-coded; confirm it matches
# model.config.id2label for this checkpoint, otherwise predictions are mislabelled.
LABELS = {
    0: "Allergy",
    1: "Arthritis",
    2: "Bronchial Asthma",
    3: "Cervical Spondylosis",
    4: "Chicken Pox",
    5: "Common Cold",
    6: "Dengue",
    7: "Diabetes",
    8: "Drug Reaction",
    9: "Fungal Infection",
    10: "Gastroesophageal Reflux Disease",
    11: "Hypertension",
    12: "Impetigo",
    13: "Jaundice",
    14: "Malaria",
    15: "Migraine",
    16: "Peptic Ulcer Disease",
    17: "Pneumonia",
    18: "Psoriasis",
    19: "Typhoid",
    20: "Urinary Tract Infection",
    21: "Varicose Veins",
}

# Diagnosis name -> short plain-language description shown to the user.
DESCRIPTIONS = {
    "Allergy": "An immune system reaction to foreign substances.",
    "Arthritis": "Inflammation of one or more joints.",
    "Bronchial Asthma": "A condition where the airways become inflamed and narrow.",
    "Cervical Spondylosis": "Age-related changes in the bones, discs, and joints of the neck.",
    "Chicken Pox": "A highly contagious viral infection causing an itchy skin rash.",
    "Common Cold": "A viral infection of the upper respiratory tract, causing sneezing, runny nose, and sore throat.",
    "Dengue": "A viral disease transmitted by mosquitoes, causing fever and severe pain.",
    "Diabetes": "A disease that affects how your body processes blood sugar.",
    "Drug Reaction": "An adverse response to a medication.",
    "Fungal Infection": "An infection caused by fungi affecting the skin or organs.",
    "Gastroesophageal Reflux Disease": "A chronic digestive condition where stomach acid irritates the food pipe.",
    "Hypertension": "High blood pressure that can lead to heart disease.",
    "Impetigo": "A contagious bacterial skin infection.",
    "Jaundice": "A yellowing of the skin or eyes due to liver disease.",
    "Malaria": "A serious disease transmitted by mosquito bites, causing fever and chills.",
    "Migraine": "Severe headaches often accompanied by nausea and sensitivity to light.",
    "Peptic Ulcer Disease": "Sores in the stomach lining or the upper part of the small intestine.",
    "Pneumonia": "An infection that inflames the air sacs in one or both lungs.",
    "Psoriasis": "A chronic autoimmune disease causing the rapid growth of skin cells.",
    "Typhoid": "A bacterial infection causing high fever, abdominal pain, and weakness.",
    "Urinary Tract Infection": "An infection in any part of the urinary system.",
    "Varicose Veins": "Swollen, twisted veins caused by faulty valves in the veins.",
}


def clean_input(symptom_text):
    """Normalise raw user text before tokenisation.

    Drops every non-ASCII character, then lower-cases the remainder.
    """
    symptom_text = "".join(c for c in symptom_text if ord(c) < 128)
    return symptom_text.lower()


def _classify(symptom_text):
    """Run the model on one cleaned string and return the reply text."""
    inputs = tokenizer(
        symptom_text,
        return_tensors="tf",
        padding=True,
        truncation=True,
        max_length=512,
    )
    logits = model(**inputs).logits
    # argmax over the class axis; [0] selects the single input in the batch.
    # int() turns the numpy scalar into a plain int for the dict lookup.
    prediction = int(tf.argmax(logits, axis=-1).numpy()[0])
    diagnosis = LABELS.get(prediction, "Unknown diagnosis")
    description = DESCRIPTIONS.get(diagnosis, "No description available.")
    return (
        f"Predicted Diagnosis: {diagnosis}. {description} "
        "Please consult a doctor for more accurate results."
    )


def predict(symptom_text, chat_history=None):
    """Classify the user's symptoms and append the exchange to the chat.

    Args:
        symptom_text: Raw text from the input box.
        chat_history: Current Chatbot value, a list of
            ``(user_message, bot_message)`` pairs. ``None`` is treated as an
            empty history.

    Returns:
        ``(chat_history, "")``: the updated history for the Chatbot and an
        empty string that clears the input box.
    """
    # Copy so neither the caller's list nor any shared default is mutated.
    chat_history = list(chat_history) if chat_history else []

    if not symptom_text or not symptom_text.strip():
        # Nothing to classify; leave the conversation untouched.
        return chat_history, ""

    try:
        reply = _classify(clean_input(symptom_text))
    except Exception as e:  # UI boundary: report the failure in the chat.
        logger.exception("Prediction failed")
        reply = f"Error: {str(e)}"

    # One (user, bot) pair per turn, showing the text exactly as typed.
    chat_history.append((symptom_text, reply))
    return chat_history, ""


# Gradio UI
with gr.Blocks() as interface:
    gr.Markdown("Medi Mind - Your AI Health Assistant")
    chatbot = gr.Chatbot()
    input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
    send_button = gr.Button("Send")
    # Enter in the textbox and the Send button trigger the same handler.
    input_box.submit(predict, [input_box, chatbot], [chatbot, input_box])
    send_button.click(predict, [input_box, chatbot], [chatbot, input_box])

if __name__ == "__main__":
    # NOTE(review): share=True publishes a public gradio.live URL, and
    # server_name="0.0.0.0" listens on every interface; confirm this exposure
    # is intended for a health-related app.
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True)