iamomtiwari committed on
Commit
bb9954b
·
verified ·
1 Parent(s): 2af3b99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -4,26 +4,34 @@ from PIL import Image
4
  import torch
5
 
6
  # Load the pre-trained ViT model and processor
7
- processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
8
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
9
 
10
  # Inference function for predicting with ViT model
11
  def predict(image):
12
- # Preprocess the input image using the processor
13
- inputs = processor(images=image, return_tensors="pt")
14
-
15
- # Get the model's predictions
16
- with torch.no_grad():
17
- outputs = model(**inputs)
18
- logits = outputs.logits
19
-
20
- # Get the predicted class index (class with the highest logit value)
21
- predicted_class_idx = logits.argmax(-1).item()
 
 
 
 
 
22
 
23
- # Get the human-readable label for the predicted class
24
- predicted_class_label = model.config.id2label[predicted_class_idx]
25
-
26
- return f"Predicted class: {predicted_class_label}"
 
 
 
27
 
28
  # Create Gradio Interface (Note the change here: `gr.Image` and `gr.Text`)
29
  interface = gr.Interface(fn=predict,
 
4
  import torch
5
 
6
  # Load the pre-trained ViT model and processor
7
+ processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k') # Using the in21k pre-trained model
8
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
9
 
10
  # Inference function for predicting with ViT model
11
  def predict(image):
12
+ try:
13
+ # Ensure the image is in PIL format
14
+ if isinstance(image, str):
15
+ image = Image.open(image)
16
+
17
+ # Preprocess the input image using the processor
18
+ inputs = processor(images=image, return_tensors="pt")
19
+
20
+ # Get the model's predictions
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+ logits = outputs.logits
24
+
25
+ # Get the predicted class index (class with the highest logit value)
26
+ predicted_class_idx = logits.argmax(-1).item()
27
 
28
+ # Get the human-readable label for the predicted class
29
+ predicted_class_label = model.config.id2label[predicted_class_idx]
30
+
31
+ return f"Predicted class: {predicted_class_label}"
32
+
33
+ except Exception as e:
34
+ return f"Error: {str(e)}"
35
 
36
  # Create Gradio Interface (Note the change here: `gr.Image` and `gr.Text`)
37
  interface = gr.Interface(fn=predict,