Jangai commited on
Commit
ca6e9aa
·
verified ·
1 Parent(s): 73d50df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import ViTFeatureExtractor, ViTForImageClassification
4
  from PIL import Image
5
 
6
  # Load the pre-trained model and feature extractor
7
  model_name = "google/vit-base-patch16-224"
8
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
 
11
  # Define the prediction function
12
  def predict(image):
13
- inputs = feature_extractor(images=image, return_tensors="pt")
14
  outputs = model(**inputs)
15
  logits = outputs.logits
16
  predicted_class_idx = logits.argmax(-1).item()
@@ -19,8 +19,8 @@ def predict(image):
19
  # Create the Gradio interface
20
  iface = gr.Interface(
21
  fn=predict,
22
- inputs=gr.inputs.Sketchpad(),
23
- outputs=gr.outputs.Label(),
24
  title="Drawing Classifier",
25
  description="Draw something and the model will try to identify it!"
26
  )
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
5
 
6
  # Load the pre-trained model and feature extractor
7
  model_name = "google/vit-base-patch16-224"
8
+ image_processor = ViTImageProcessor.from_pretrained(model_name)
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
 
11
  # Define the prediction function
12
  def predict(image):
13
+ inputs = image_processor(images=image, return_tensors="pt")
14
  outputs = model(**inputs)
15
  logits = outputs.logits
16
  predicted_class_idx = logits.argmax(-1).item()
 
19
  # Create the Gradio interface
20
  iface = gr.Interface(
21
  fn=predict,
22
+ inputs=gr.components.Sketchpad(),
23
+ outputs=gr.components.Label(),
24
  title="Drawing Classifier",
25
  description="Draw something and the model will try to identify it!"
26
  )