parneetsingh022 commited on
Commit
59d9c25
·
verified ·
1 Parent(s): b87a427

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -49,22 +49,28 @@ model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
49
 
50
 
51
  def predict(image):
52
- # Convert NumPy array to PIL Image
53
- image = Image.fromarray(np.uint8(image)).convert('RGB')
54
-
55
  preprocess = transforms.Compose([
56
  transforms.Resize((128, 128)),
57
  transforms.ToTensor(),
58
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
59
  ])
60
 
61
- classes = ['cat', 'dog']
 
62
 
63
  x = preprocess(image).unsqueeze(0)
64
- x = model(x)
65
- probabilities = torch.nn.functional.softmax(x, dim=1)
66
-
67
- return classes[probabilities.argmax(dim=1).item()]
68
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  demo = gr.Interface(fn=predict, inputs="image", outputs="text")
70
  demo.launch()
 
49
 
50
 
51
  def predict(image):
 
 
 
52
  preprocess = transforms.Compose([
53
  transforms.Resize((128, 128)),
54
  transforms.ToTensor(),
55
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
56
  ])
57
 
58
+ # Ensure the image is a PIL Image
59
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
60
 
61
  x = preprocess(image).unsqueeze(0)
 
 
 
 
62
 
63
+ # Set model to evaluation mode
64
+ model.eval()
65
+
66
+ with torch.no_grad(): # Use no_grad context for inference to save memory and computations
67
+ x = model(x)
68
+ probabilities = torch.nn.functional.softmax(x, dim=1)
69
+ class_id = probabilities.argmax(dim=1).item()
70
+
71
+ classes = ['cat', 'dog']
72
+ return classes[class_id]
73
+
74
+ # Update Gradio interface
75
  demo = gr.Interface(fn=predict, inputs="image", outputs="text")
76
  demo.launch()