import torch import torch.nn as nn from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel from PIL import Image import gradio as gr # Model definition and setup class VisionLanguageModel(nn.Module): def __init__(self): super(VisionLanguageModel, self).__init__() self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k') self.language_model = BertModel.from_pretrained('bert-base-uncased') self.classifier = nn.Linear( self.vision_model.config.hidden_size + self.language_model.config.hidden_size, 2 # Number of classes: benign or malignant ) def forward(self, input_ids, attention_mask, pixel_values): vision_outputs = self.vision_model(pixel_values=pixel_values) vision_pooled_output = vision_outputs.pooler_output language_outputs = self.language_model( input_ids=input_ids, attention_mask=attention_mask ) language_pooled_output = language_outputs.pooler_output combined_features = torch.cat( (vision_pooled_output, language_pooled_output), dim=1 ) logits = self.classifier(combined_features) return logits model = VisionLanguageModel() model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True)) model.eval() tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k') def predict(image, text_input): image = feature_extractor(images=image, return_tensors="pt").pixel_values encoding = tokenizer( text_input, add_special_tokens=True, max_length=256, padding='max_length', truncation=True, return_tensors='pt' ) with torch.no_grad(): outputs = model( input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'], pixel_values=image ) _, prediction = torch.max(outputs, dim=1) return prediction.item() # 1 for Malignant, 0 for Benign # Enhanced UI with black text with gr.Blocks(css=""" body { color: black; } .benign, .malignant { background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px; color: black; } .benign.correct, .malignant.correct { background-color: lightgreen; color: black; } """) as demo: gr.Markdown( """ # 🩺 SKIN LESION CLASSIFICATION Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant. """ ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil", label="Upload Skin Lesion Image") text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)") with gr.Column(scale=1): gr.Markdown("## PREDICTION RESULTS") benign_output = gr.HTML("