File size: 3,537 Bytes
717e49e
b2fca83
 
 
 
 
 
 
 
 
 
 
 
 
2736466
 
 
b2fca83
 
0e362ce
 
 
190f7e2
0e362ce
b2fca83
823b2fe
b2fca83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e362ce
 
 
 
 
b2fca83
0e362ce
b2fca83
0e362ce
 
b2fca83
0e362ce
 
 
 
 
 
 
 
b2fca83
 
0e362ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb450b
0e362ce
2d3a8bd
 
 
 
 
0e362ce
2d3a8bd
 
0e362ce
 
2cb721d
 
 
0e362ce
 
 
 
 
 
0dfcb33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import joblib
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import yaml
import traceback
import timm
import logging
from fastapi.logger import logger


# ASGI application object; its /predict route is currently commented out below.
app = FastAPI()

# Prefer a CUDA device whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Class name -> output index of the classifier head.
# NOTE(review): presumably this is the label encoding used at training time — confirm.
class_mapping = {'tb': 0, 'healthy': 1, 'sick_but_no_tb': 2}
# Output index -> class name, used to decode argmax predictions.
reverse_mapping = {v: k for k, v in class_mapping.items()}
# Class names ordered by output index, so labels[i] names logit i. Built from
# reverse_mapping rather than relying on dict key order, so it stays correct if
# the literal above is reordered (and raises KeyError if indices are not 0..n-1).
labels = [reverse_mapping[i] for i in range(len(reverse_mapping))]

def load_model(checkpoint_path='model.pth',
               arch='convnext_base.clip_laiona',
               num_classes=3):
    """Build the classifier and load its fine-tuned weights from disk.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
        arch: timm model name; must match the architecture the checkpoint
            was saved from.
        num_classes: Size of the classification head; must match the
            checkpoint.

    Returns:
        The model with the checkpoint weights loaded, in eval mode.
    """
    # pretrained=False: load_state_dict below (strict by default) overwrites
    # every parameter, so fetching the upstream pretrained weights first is
    # pure startup cost and an unnecessary network dependency.
    model = timm.create_model(arch, pretrained=False, num_classes=num_classes)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    # weights_only=True refuses to unpickle arbitrary Python objects; a plain
    # state_dict of tensors loads fine under it.
    state_dict = torch.load(checkpoint_path, map_location=device,
                            weights_only=True)
    # NOTE(review): load_state_dict copies values into the model's existing
    # (CPU) parameters, so the model itself is not moved to `device`.
    model.load_state_dict(state_dict)
    model.eval()
    return model


def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched model input.

    The image is converted to RGB, resized so its shorter side is 255 px,
    center-cropped to 224x224, turned into a float tensor and normalized with
    the ImageNet channel mean/std. A leading batch axis of size 1 is added.

    NOTE(review): Resize(255) may have been meant as the usual 256, and this
    pipeline differs from the one `predict` uses (no crop/normalization
    there) — confirm which matches training.
    """
    pipeline = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    buffer = io.BytesIO(image_bytes)
    rgb_image = Image.open(buffer).convert('RGB')
    single = pipeline(rgb_image)
    return single.unsqueeze(0)


def get_prediction(data):
    """Classify raw image bytes and return the predicted class name.

    Uses the module-level ``model`` (bound further down the file) and
    ``reverse_mapping`` to decode the winning output index.
    """
    batch = transform_image(data)
    with torch.no_grad():
        logits = model(batch)
    # The batch holds a single image, so the flat argmax is the class index.
    top_index = logits.argmax().item()
    return reverse_mapping[top_index]

# Lower-case file extensions (without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    """Return True if *filename* ends in one of ``ALLOWED_EXTENSIONS``.

    The check is case-insensitive and looks only at the text after the last
    dot. A missing or empty filename (``UploadFile.filename`` may be None)
    is rejected instead of raising ``TypeError``.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS


# @app.get("/predict")
# async def predict(file: UploadFile = File(...)):
#     """
#     Perform prediction on the uploaded image
#     """

#     logger.info('API predict called')

#     if not allowed_file(file.filename):
#         raise HTTPException(status_code=400, detail="Format not supported")
    
#     try:
#         img_bytes = await file.read()
#         class_name = get_prediction(img_bytes)
#         logger.info(f'Prediction: {class_name}')
#         return JSONResponse(content={"class_name": class_name})
#     except Exception as e:
#         logger.error(f'Error: {str(e)}')
#         return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)
     

# # @app.get("/")
# # def greet_json():
# #     return {"Hello": "World!"}


import torch
import requests
from PIL import Image
from torchvision import transforms

# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()


# Download human-readable labels for ImageNet.
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")

# Build the classifier once at import time so every request reuses it.
model = load_model()

# Preprocessing for the Gradio path: resize to exactly 224x224 (aspect ratio
# not preserved) and convert to a float tensor in [0, 1].
# NOTE(review): unlike transform_image, no center crop or mean/std
# normalization is applied here — confirm which pipeline matches training.
augs = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def predict(inp):
    """Return per-class confidences for a PIL image (Gradio callback).

    Args:
        inp: A ``PIL.Image.Image`` as delivered by ``gr.Image(type="pil")``.

    Returns:
        Dict mapping each class name in ``labels`` to its softmax
        probability, the format ``gr.Label`` accepts.
    """
    # Force 3 channels (as transform_image does): RGBA or greyscale images
    # would otherwise produce a tensor with the wrong channel count.
    batch = augs(inp.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        # Row 0 of the logits is the single image in the batch.
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair each probability with its label by index instead of a hard-coded
    # class count, so this follows the model's head size.
    return {labels[i]: float(p) for i, p in enumerate(probs)}


import gradio as gr

# Web UI: a PIL image in, the top-3 class confidences out.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=gr.Label(num_top_classes=3))

# Start the server only when run as a script, so importing this module
# (e.g. to serve the FastAPI `app`) does not block inside launch().
if __name__ == "__main__":
    demo.launch(share=True)