dmusingu commited on
Commit
b2fca83
·
verified ·
1 Parent(s): bc7538f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py CHANGED
@@ -1,7 +1,78 @@
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
1
  from fastapi import FastAPI
2
+ from fastapi.responses import JSONResponse
3
+ import io
4
+ import joblib
5
+ import torch
6
+ import numpy as np
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ import yaml
10
+ import traceback
11
+ import timm
12
+ import logging
13
+ from fastapi.logger import logger
14
+
15
 
16
  app = FastAPI()
17
 
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ def load_model():
21
+ config = read_params(config_path)
22
+ model = timm.create_model('convnext_base.clip_laiona', pretrained=True, num_classes=3)
23
+ model_state_dict = torch.load('model.pth', map_location=device)
24
+ model.load_state_dict(model_state_dict)
25
+ model.eval()
26
+ return model
27
+
28
+ model = load_model()
29
+
30
+ def transform_image(image_bytes):
31
+ my_transforms = transforms.Compose([transforms.Resize(255),
32
+ transforms.CenterCrop(224),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(
35
+ [0.485, 0.456, 0.406],
36
+ [0.229, 0.224, 0.225])])
37
+ image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
38
+ return my_transforms(image).unsqueeze(0)
39
+
40
+
41
+ def get_prediction(data):
42
+ tensor = transform_image(data)
43
+ # model = app.package['model']
44
+ with torch.no_grad():
45
+ prediction = model(tensor)
46
+ prediction = reverse_mapping[prediction.argmax().item()]
47
+ return prediction
48
+
49
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
50
+
51
+ def allowed_file(filename):
52
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
53
+
54
+
55
+ @app.post("/predict")
56
+ async def predict(file: UploadFile = File(...)):
57
+ """
58
+ Perform prediction on the uploaded image
59
+ """
60
+
61
+ logger.info('API predict called')
62
+
63
+ if not allowed_file(file.filename):
64
+ raise HTTPException(status_code=400, detail="Format not supported")
65
+
66
+ try:
67
+ img_bytes = await file.read()
68
+ class_name = get_prediction(img_bytes)
69
+ logger.info(f'Prediction: {class_name}')
70
+ return JSONResponse(content={"class_name": class_name})
71
+ except Exception as e:
72
+ logger.error(f'Error: {str(e)}')
73
+ return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500)
74
+
75
+
76
  @app.get("/")
77
  def greet_json():
78
  return {"Hello": "World!"}