timm
/

rwightman HF staff committed on
Commit
9179d15
1 Parent(s): c85ba24

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -6
README.md CHANGED
@@ -28,17 +28,18 @@ import torch
28
  import torch.nn.functional as F
29
  from urllib.request import urlopen
30
  from PIL import Image
31
- from open_clip import create_model_from_pretrained, get_tokenizer
32
 
33
- model, preprocess = create_model_from_pretrained('hf-hub:ViT-SO400M-14-SigLIP')
34
- tokenizer = get_tokenizer('hf-hub:ViT-SO400M-14-SigLIP')
35
 
36
  image = Image.open(urlopen(
37
  'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
38
  ))
39
  image = preprocess(image).unsqueeze(0)
40
 
41
- text = tokenizer(["a diagram", "a dog", "a cat", "a beignet"], context_length=model.context_length)
 
42
 
43
  with torch.no_grad(), torch.cuda.amp.autocast():
44
  image_features = model.encode_image(image)
@@ -46,9 +47,10 @@ with torch.no_grad(), torch.cuda.amp.autocast():
46
  image_features = F.normalize(image_features, dim=-1)
47
  text_features = F.normalize(text_features, dim=-1)
48
 
49
- text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
50
 
51
- print("Label probs:", text_probs) # prints: [[0., 0., 0., 1.0]]
 
52
  ```
53
 
54
  ### With `timm` (for image embeddings)
 
28
  import torch.nn.functional as F
29
  from urllib.request import urlopen
30
  from PIL import Image
31
+ from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
32
 
33
+ model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP')
34
+ tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP')
35
 
36
  image = Image.open(urlopen(
37
  'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
38
  ))
39
  image = preprocess(image).unsqueeze(0)
40
 
41
+ labels_list = ["a dog", "a cat", "a donut", "a beignet"]
42
+ text = tokenizer(labels_list, context_length=model.context_length)
43
 
44
  with torch.no_grad(), torch.cuda.amp.autocast():
45
  image_features = model.encode_image(image)
 
47
  image_features = F.normalize(image_features, dim=-1)
48
  text_features = F.normalize(text_features, dim=-1)
49
 
50
+ text_probs = torch.sigmoid(image_features @ text_features.T * model.logit_scale.exp() + model.logit_bias)
51
 
52
+ zipped_list = list(zip(labels_list, [round(p.item(), 3) for p in text_probs[0]]))
53
+ print("Label probabilities: ", zipped_list)
54
  ```
55
 
56
  ### With `timm` (for image embeddings)