"""Serve a three-class image classifier (``tb`` / ``healthy`` / ``sick_but_no_tb``).

The active entry point is a Gradio interface wrapping ``predict``, launched at
import time. A FastAPI app plus helpers (``transform_image``,
``get_prediction``, ``allowed_file``) are kept for a JSON endpoint that is
currently commented out.
"""
import io
import logging
import traceback

import gradio as gr
import joblib
import numpy as np
import requests
import timm
import torch
import torchvision.transforms as transforms
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.logger import logger
from fastapi.responses import JSONResponse
from PIL import Image

app = FastAPI()

# Inference device. The model and every input tensor are moved here so they
# always live on the same device (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class name -> index of the corresponding logit in the model output.
class_mapping = {'tb': 0, 'healthy': 1, 'sick_but_no_tb': 2}
reverse_mapping = {v: k for k, v in class_mapping.items()}
labels = list(class_mapping.keys())

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Preprocessing for the (currently disabled) FastAPI path; built once rather
# than on every request.
_API_TRANSFORMS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Preprocessing for the Gradio ``predict`` path.
# NOTE(review): this differs from _API_TRANSFORMS (no centre crop, no
# normalisation), so the two paths can disagree on the same image. Which one
# matches the training pipeline is not visible here -- confirm and unify.
# timm's pretrained config for this architecture tag may also specify a
# mean/std other than the ImageNet values above -- verify.
augs = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def load_model():
    """Build the ConvNeXt classifier and load the fine-tuned weights.

    Returns:
        The model in eval mode, placed on ``device``.

    Raises:
        FileNotFoundError: if ``model.pth`` is not in the working directory.
        RuntimeError: if the checkpoint keys/shapes do not match the model.
    """
    # pretrained=False: the strict load_state_dict below overwrites every
    # parameter, so downloading the hub weights first is wasted work and
    # breaks start-up on machines without network access.
    model = timm.create_model(
        'convnext_base.clip_laiona',
        pretrained=False,
        num_classes=len(class_mapping),
    )
    state_dict = torch.load('model.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def transform_image(image_bytes):
    """Decode encoded image bytes into a ``(1, 3, 224, 224)`` tensor on ``device``.

    Args:
        image_bytes: raw bytes of an image file readable by PIL.

    Returns:
        A normalised float tensor with a leading batch dimension of 1.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return _API_TRANSFORMS(image).unsqueeze(0).to(device)


def get_prediction(data):
    """Classify one encoded image and return the predicted class name.

    Args:
        data: raw bytes of an image file.

    Returns:
        The ``class_mapping`` key whose logit is highest.
    """
    tensor = transform_image(data)
    # ``model`` is the module-level instance created below by load_model().
    with torch.no_grad():
        logits = model(tensor)
    return reverse_mapping[logits.argmax().item()]


def allowed_file(filename):
    """Return True if ``filename`` has an extension in ``ALLOWED_EXTENSIONS``.

    The comparison is case-insensitive and uses only the last dot-separated
    suffix.
    """
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


# NOTE(review): disabled JSON endpoint. If re-enabled it should be
# ``@app.post``: browsers and many HTTP clients cannot send a multipart body
# with GET. Also, returning ``traceback.format_exc()`` to clients leaks
# internals; log it instead.
# @app.get("/predict")
# async def predict(file: UploadFile = File(...)):
#     """
#     Perform prediction on the uploaded image
#     """
#     logger.info('API predict called')
#     if not allowed_file(file.filename):
#         raise HTTPException(status_code=400, detail="Format not supported")
#     try:
#         img_bytes = await file.read()
#         class_name = get_prediction(img_bytes)
#         logger.info(f'Prediction: {class_name}')
#         return JSONResponse(content={"class_name": class_name})
#     except Exception as e:
#         logger.error(f'Error: {str(e)}')
#         return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)

# # @app.get("/")
# # def greet_json():
# #     return {"Hello": "World!"}


def predict(inp):
    """Gradio callback: return a probability for every class.

    Args:
        inp: a ``PIL.Image`` as supplied by ``gr.Image(type="pil")``.

    Returns:
        Dict mapping each class name to its softmax probability (a float).
    """
    # Force 3 channels: ToTensor on an RGBA or greyscale image would produce
    # a 4- or 1-channel tensor that the model's stem conv cannot accept.
    batch = augs(inp.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    return {reverse_mapping[i]: float(p) for i, p in enumerate(probs.cpu().tolist())}


model = load_model()

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(class_mapping)),
)
# Launched at import time, as in the original script.
demo.launch()