Jangai commited on
Commit
0c2a5d3
·
verified ·
1 Parent(s): 6af9b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
5
  import numpy as np
6
  import logging
@@ -8,11 +8,11 @@ import logging
8
  # Configure logging
9
  logging.basicConfig(level=logging.DEBUG)
10
 
11
- # Load the pre-trained model and image processor
12
- model_name = "google/vit-base-patch16-224"
13
  logging.info("Loading image processor and model...")
14
- image_processor = ViTImageProcessor.from_pretrained(model_name)
15
- model = ViTForImageClassification.from_pretrained(model_name)
16
 
17
  # Define the prediction function
18
  def predict(image):
@@ -31,7 +31,7 @@ def predict(image):
31
  logging.debug("Image converted successfully.")
32
 
33
  logging.info("Processing image...")
34
- inputs = image_processor(images=image, return_tensors="pt")
35
  outputs = model(**inputs)
36
  logits = outputs.logits
37
  predicted_class_idx = logits.argmax(-1).item()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
4
  from PIL import Image
5
  import numpy as np
6
  import logging
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.DEBUG)
10
 
11
+ # Load the pre-trained model and feature extractor
12
+ model_name = "google/vit-base-patch16-224-in21k"
13
  logging.info("Loading image processor and model...")
14
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
15
+ model = AutoModelForImageClassification.from_pretrained(model_name)
16
 
17
  # Define the prediction function
18
  def predict(image):
 
31
  logging.debug("Image converted successfully.")
32
 
33
  logging.info("Processing image...")
34
+ inputs = feature_extractor(images=image, return_tensors="pt")
35
  outputs = model(**inputs)
36
  logits = outputs.logits
37
  predicted_class_idx = logits.argmax(-1).item()