"""Gradio demo for a binary diabetic-retinopathy classifier.

Loads the ``google/vit-large-patch16-224`` image classifier with a two-label
head, applies the ``monsoon-nlp/eyegazer-vit-binary`` PEFT adapter on top of
it, and serves a small web UI that reports whether an uploaded fundus image is
predicted as unaffected or affected.
"""

import os

import gradio as gr
from peft import PeftModel
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

model_name = "google/vit-large-patch16-224"
adapter = "monsoon-nlp/eyegazer-vit-binary"

# The processor supplies the normalization statistics and the target size
# used to build the preprocessing pipelines below.
image_processor = AutoImageProcessor.from_pretrained(model_name)
normalize = Normalize(
    mean=image_processor.image_mean,
    std=image_processor.image_std,
)

# Augmenting pipeline; not used by the demo itself but kept available at
# module level as in the original script.
train_transforms = Compose(
    [
        RandomResizedCrop(image_processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Deterministic pipeline applied to every image passed to `query`.
val_transforms = Compose(
    [
        Resize(image_processor.size["height"]),
        CenterCrop(image_processor.size["height"]),
        ToTensor(),
        normalize,
    ]
)

# The base checkpoint is loaded with a 2-label head; `ignore_mismatched_sizes`
# lets the classifier layer be re-created at the new size.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=2,
)
lora_model = PeftModel.from_pretrained(model, adapter)
# Explicitly switch to evaluation mode so layers such as dropout are inactive.
lora_model.eval()


def query(img) -> str:
    """Classify a single image and return a human-readable verdict.

    Args:
        img: A PIL image as delivered by the Gradio image input, or ``None``
            when no image was supplied.

    Returns:
        ``"Predicted unaffected"`` when the logit at index 0 is strictly
        greater than the logit at index 1, otherwise
        ``"Predicted affected to some degree"``. If ``img`` is ``None`` a
        short prompt asking for an image is returned instead.
    """
    if img is None:
        # Avoid an AttributeError on `.convert` when nothing was uploaded.
        return "Please provide an image."

    pimg = val_transforms(img.convert("RGB"))
    batch = pimg.unsqueeze(0)  # add the leading batch dimension

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        op = lora_model(batch)

    vals = op.logits.tolist()[0]
    if vals[0] > vals[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"


# Example images live in an `images/` folder next to this script.
_base_dir = os.path.dirname(__file__)
_example_files = [
    "images/i1.png",
    "images/0a09aa7356c0.png",
    "images/0a4e1a29ffff.png",
    "images/0c43c79e8cfb.png",
    "images/0c7e82daf5a0.png",
]

iface = gr.Interface(
    fn=query,
    examples=[os.path.join(_base_dir, name) for name in _example_files],
    inputs=[
        # `gr.inputs.Image` is the removed legacy API and never accepted
        # `sources`; the top-level component is the one that takes it.
        gr.Image(
            image_mode="RGB",
            sources=["upload", "clipboard"],
            type="pil",
            label="Input Fundus Camera Image",
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description=(
        "Diabetic retinopathy model trained on APTOS 2019 dataset; "
        "demonstration, not medical advice"
    ),
    allow_flagging="never",
)

iface.launch()