TadsHugger commited on
Commit
5c64cba
1 Parent(s): 0acba1d

Create classifier

Browse files
Files changed (1) hide show
  1. classifier +16 -0
classifier ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
2
+ from PIL import Image
3
+ import torch
4
+
5
+ # Loading in Model
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model = ViTForImageClassification.from_pretrained( "imjeffhi/pokemon_classifier").to(device)
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('imjeffhi/pokemon_classifier')
9
+
10
+ # Caling the model on a test image
11
+ img = Image.open('test.jpg')
12
+ extracted = feature_extractor(images=img, return_tensors='pt').to(device)
13
+ predicted_id = model(**extracted).logits.argmax(-1).item()
14
+ predicted_pokemon = model.config.id2label[predicted_id]
15
+
16
+ print('Predicted Pokemon:', predicted_pokemon)