dmusingu commited on
Commit
0e362ce
·
verified ·
1 Parent(s): b3b9f53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -19
app.py CHANGED
@@ -17,6 +17,10 @@ app = FastAPI()
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
 
 
 
 
20
  def load_model():
21
  # config = read_params(config_path)
22
  model = timm.create_model('convnext_base.clip_laiona', pretrained=True, num_classes=3)
@@ -52,27 +56,56 @@ def allowed_file(filename):
52
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
53
 
54
 
55
- @app.get("/predict")
56
- async def predict(file: UploadFile = File(...)):
57
- """
58
- Perform prediction on the uploaded image
59
- """
60
 
61
- logger.info('API predict called')
62
 
63
- if not allowed_file(file.filename):
64
- raise HTTPException(status_code=400, detail="Format not supported")
65
 
66
- try:
67
- img_bytes = await file.read()
68
- class_name = get_prediction(img_bytes)
69
- logger.info(f'Prediction: {class_name}')
70
- return JSONResponse(content={"class_name": class_name})
71
- except Exception as e:
72
- logger.error(f'Error: {str(e)}')
73
- return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)
74
 
75
 
76
- # @app.get("/")
77
- # def greet_json():
78
- # return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
+
21
+ class_mapping = {'tb': 0, 'healthy': 1, 'sick_but_no_tb': 2}
22
+ reverse_mapping = {v: k for k, v in class_mapping.items()}
23
+
24
  def load_model():
25
  # config = read_params(config_path)
26
  model = timm.create_model('convnext_base.clip_laiona', pretrained=True, num_classes=3)
 
56
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
57
 
58
 
59
+ # @app.get("/predict")
60
+ # async def predict(file: UploadFile = File(...)):
61
+ # """
62
+ # Perform prediction on the uploaded image
63
+ # """
64
 
65
+ # logger.info('API predict called')
66
 
67
+ # if not allowed_file(file.filename):
68
+ # raise HTTPException(status_code=400, detail="Format not supported")
69
 
70
+ # try:
71
+ # img_bytes = await file.read()
72
+ # class_name = get_prediction(img_bytes)
73
+ # logger.info(f'Prediction: {class_name}')
74
+ # return JSONResponse(content={"class_name": class_name})
75
+ # except Exception as e:
76
+ # logger.error(f'Error: {str(e)}')
77
+ # return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)
78
 
79
 
80
+ # # @app.get("/")
81
+ # # def greet_json():
82
+ # # return {"Hello": "World!"}
83
+
84
+
85
+ import torch
86
+ import requests
87
+ from PIL import Image
88
+ from torchvision import transforms
89
+
90
+ # model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
91
+
92
+
93
+ # Download human-readable labels for ImageNet.
94
+ # response = requests.get("https://git.io/JJkYN")
95
+ # labels = response.text.split("\n")
96
+
97
+
98
+ def predict(inp):
99
+ inp = transforms.ToTensor()(inp).unsqueeze(0)
100
+ with torch.no_grad():
101
+ prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
102
+ # confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
103
+ prediction = reverse_mapping[prediction]
104
+ return prediction
105
+
106
+
107
+ import gradio as gr
108
+
109
+ gr.Interface(fn=predict,
110
+ inputs=gr.Image(type="pil"),
111
+ outputs=gr.Label(num_top_classes=3)).launch()