Spaces:
Sleeping
Sleeping
File size: 3,527 Bytes
717e49e b2fca83 2736466 b2fca83 0e362ce 190f7e2 0e362ce b2fca83 823b2fe b2fca83 0e362ce b2fca83 0e362ce b2fca83 0e362ce b2fca83 0e362ce b2fca83 0e362ce afb450b 0e362ce 2d3a8bd 0e362ce 2d3a8bd 0e362ce 2cb721d 0e362ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import joblib
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import yaml
import traceback
import timm
import logging
from fastapi.logger import logger
# ASGI application object. The only route visible in this file is commented
# out, so nothing is registered on it here.
app = FastAPI()

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Class name -> index of the corresponding model output.
class_mapping = {'tb': 0, 'healthy': 1, 'sick_but_no_tb': 2}
# Model output index -> class name (inverse of class_mapping).
reverse_mapping = {index: name for name, index in class_mapping.items()}
# Class names in insertion order, so labels[i] is the name for output index i.
labels = list(class_mapping)
def load_model():
    """Build the 3-class ConvNeXt classifier and load its fine-tuned weights.

    The architecture is created through ``timm`` and its weights are then
    replaced by the state dict stored in ``model.pth``. ``load_state_dict``
    is strict by default, so the checkpoint must cover the whole model.

    The model is not moved to ``device`` here; only the checkpoint tensors
    are mapped to it while loading.

    Returns:
        The model, switched to evaluation mode.
    """
    # pretrained=False: every weight is overwritten by the checkpoint below,
    # so fetching the upstream pretrained weights first would only add a
    # large, pointless download at startup.
    model = timm.create_model(
        'convnext_base.clip_laiona', pretrained=False, num_classes=3
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    model_state_dict = torch.load('model.pth', map_location=device)
    model.load_state_dict(model_state_dict)
    model.eval()
    return model
def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched input tensor.

    The image is converted to RGB, resized so its shorter side is 255 px,
    center-cropped to 224x224, converted to a tensor and normalized with
    the widely used ImageNet channel statistics.

    Args:
        image_bytes: Encoded image file content (anything PIL can open).

    Returns:
        The preprocessed image tensor with a leading batch dimension of 1.
    """
    decoded = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    steps = [
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    pipeline = transforms.Compose(steps)
    single_image = pipeline(decoded)
    # Add the batch axis the model expects in front of the image tensor.
    return single_image.unsqueeze(0)
def get_prediction(data):
    """Classify an encoded image and return the predicted class name.

    Args:
        data: Encoded image bytes, passed straight to ``transform_image``.

    Returns:
        The ``reverse_mapping`` entry for the highest-scoring model output.
    """
    batch = transform_image(data)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        scores = model(batch)
    best_index = scores.argmax().item()
    return reverse_mapping[best_index]
# Lower-case file extensions (without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    """Return whether *filename* ends in one of ``ALLOWED_EXTENSIONS``.

    The extension is the text after the last dot and is compared
    case-insensitively.

    Args:
        filename: The file name to check. May be ``None`` or empty (an
            upload can arrive without a name); both are rejected.

    Returns:
        ``True`` if the extension is allowed, ``False`` otherwise.
    """
    # Guard first: a missing name would otherwise raise on the dot lookup.
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    # `dot` is empty when the name contains no '.' at all.
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
# @app.get("/predict")
# async def predict(file: UploadFile = File(...)):
# """
# Perform prediction on the uploaded image
# """
# logger.info('API predict called')
# if not allowed_file(file.filename):
# raise HTTPException(status_code=400, detail="Format not supported")
# try:
# img_bytes = await file.read()
# class_name = get_prediction(img_bytes)
# logger.info(f'Prediction: {class_name}')
# return JSONResponse(content={"class_name": class_name})
# except Exception as e:
# logger.error(f'Error: {str(e)}')
# return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)
# # @app.get("/")
# # def greet_json():
# # return {"Hello": "World!"}
import torch
import requests
from PIL import Image
from torchvision import transforms
# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
# Download human-readable labels for ImageNet.
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")
# Load the classifier once at import time; predict() reuses this instance.
model = load_model()

# Preprocessing for the Gradio path: resize straight to 224x224 and convert
# to a tensor.
# NOTE(review): unlike transform_image(), this pipeline has no Normalize step
# and resizes without center-cropping. Confirm which one matches training.
augs = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
def predict(inp):
    """Return per-class confidences for one image (Gradio callback).

    Args:
        inp: The input image; it is passed through ``augs``, so it must be
            something that pipeline accepts (the interface supplies a PIL
            image).

    Returns:
        A dict mapping each of the three class names in ``labels`` to its
        softmax probability as a float.
    """
    batch = augs(inp).unsqueeze(0)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        # Drop the batch axis, then turn the raw outputs into probabilities.
        scores = model(batch)[0]
        probabilities = torch.nn.functional.softmax(scores, dim=0)
        confidences = {}
        for index in range(3):
            confidences[labels[index]] = float(probabilities[index])
    return confidences
import gradio as gr
# Build the web UI: one image input (handed to predict() as a PIL image) and
# a label output showing the top three classes. launch() starts the server
# as soon as this module is executed.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
)
demo.launch()