parneetsingh022 commited on
Commit
1c7b13d
·
verified ·
1 Parent(s): 05d5fb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -47,21 +47,23 @@ model = CustomModel(input_shape=(3,128,128), num_classes=2)
47
  model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
48
 
49
 
50
- def greet(image):
 
 
 
51
  preprocess = transforms.Compose([
52
- transforms.Resize((128, 128)),
53
- transforms.ToTensor(),
54
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
  ])
56
 
57
  classes = ['cat', 'dog']
58
 
59
  x = preprocess(image).unsqueeze(0)
60
  x = model(x)
61
- output = torch.nn.functional.softmax(x, dim=1)
62
-
63
 
64
  return classes[probabilities.argmax(dim=1).item()]
65
 
66
- demo = gr.Interface(fn=greet, inputs="image", outputs="text")
67
  demo.launch()
 
47
  model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
48
 
49
 
50
+ def predict(image):
51
+ # Convert NumPy array to PIL Image
52
+ image = Image.fromarray(np.uint8(image)).convert('RGB')
53
+
54
  preprocess = transforms.Compose([
55
+ transforms.Resize((128, 128)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
58
  ])
59
 
60
  classes = ['cat', 'dog']
61
 
62
  x = preprocess(image).unsqueeze(0)
63
  x = model(x)
64
+ probabilities = torch.nn.functional.softmax(x, dim=1)
 
65
 
66
  return classes[probabilities.argmax(dim=1).item()]
67
 
68
+ demo = gr.Interface(fn=predict, inputs="image", outputs="text")
69
  demo.launch()