kaveh committed
Commit c877192
1 Parent(s): 0453a39

added example applications

Files changed (1): README.md (+95 -1)

README.md CHANGED
Here is a heatmap of the similarity scores between the images and captions of the first 30 samples on the test split of the ROCO dataset:

![heatmap](https://imgur.com/fPFM694.png)
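
A heatmap like this can be produced along the following lines. This is a minimal sketch, not the original evaluation code: the image paths and captions below are placeholders standing in for the first 30 samples of the ROCO test split.

```
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# placeholders for the first 30 (image, caption) pairs of the ROCO test split
image_paths = [f"/path/to/roco/test/{i}.jpg" for i in range(30)]
captions = [f"caption of sample {i}" for i in range(30)]

# embed all images and all captions
with torch.no_grad():
    img_inputs = processor(images=[Image.open(p) for p in image_paths], return_tensors="pt")
    txt_inputs = processor(text=captions, return_tensors="pt", padding=True, truncation=True)
    image_embeds = model.get_image_features(**img_inputs)
    text_embeds = model.get_text_features(**txt_inputs)

# cosine similarity of every caption against every image
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
similarity = (text_embeds @ image_embeds.T).numpy()

# plot: a bright diagonal means each image matches its own caption best
plt.imshow(similarity)
plt.xlabel("images")
plt.ylabel("captions")
plt.colorbar()
plt.show()
```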

## Applications

### Image Retrieval

This model can be used for image retrieval, as demonstrated below:

#### Save Image Embeddings

```
import os
import pickle

import numpy as np
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# load the model and its processor
model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# TODO: set this to your own image directory
images_path = "/path/to/images/"
images = [os.path.join(images_path, i) for i in os.listdir(images_path) if i.endswith(".jpg")]

# generate an embedding for every image in the dataset
image_embeds = []
for img in images:
    with torch.no_grad():
        inputs = processor(text=None, images=Image.open(img), return_tensors="pt", padding=True)
        outputs = model.get_image_features(**inputs)[0].numpy()
    image_embeds.append(outputs)

# save the image embeddings to a pickle file
with open("embeddings.pkl", "wb") as f:
    pickle.dump(np.array(image_embeds), f)
```
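
Precomputing and pickling the embeddings means the comparatively expensive image encoder runs only once per image; answering a query later takes a single text forward pass plus a cosine-similarity lookup.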

#### Query for Images

```
import os
import pickle

import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# load the model and its processor
model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# the text query to search the image embeddings with
query = "Chest X-Ray photos"

# embed the query
inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
with torch.no_grad():
    query_embedding = model.get_text_features(**inputs)[0].numpy()

# load the image embeddings
with open("embeddings.pkl", "rb") as f:
    image_embeds = pickle.load(f)

# find the indices of the k most similar images
def find_k_similar_images(query_embedding, image_embeds, k=2):
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeds)
    closest_indices = np.argsort(similarities[0])[::-1][:k]
    return closest_indices

similar_image_indices = find_k_similar_images(query_embedding, image_embeds, k=2)

# TODO: set this to the same image directory used when saving the embeddings
images_path = "/path/to/images/"
images = [os.path.join(images_path, i) for i in os.listdir(images_path) if i.endswith(".jpg")]

# map the indices back to image paths and open the best match
similar_image_names = [images[index] for index in similar_image_indices]
Image.open(similar_image_names[0])
```
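
Note that this snippet rebuilds the file list with `os.listdir` and assumes the files come back in the same order as when the embeddings were saved, which is not guaranteed. A more robust variant (a sketch, not part of the original example) is to pickle the paths together with the embeddings:

```
# when saving: store each path next to its embedding
with open("embeddings.pkl", "wb") as f:
    pickle.dump({"paths": images, "embeds": np.array(image_embeds)}, f)

# when querying: read both back so that index i always matches path i
with open("embeddings.pkl", "rb") as f:
    data = pickle.load(f)
images, image_embeds = data["paths"], data["embeds"]
```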

### Zero-Shot Image Classification

The model can also be used for zero-shot image classification, as shown below:

```
import requests
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# fetch an example radiology image
url = "https://huggingface.co/spaces/kaveh/radiology-image-retrieval/resolve/main/images/ROCO_09402.jpg"
image = Image.open(requests.get(url, stream=True).raw)
possible_class_names = ["Chest X-Ray", "Brain MRI", "Abdominal CT Scan", "Ultrasound", "OPG"]

# score the image against every candidate label and normalize with a softmax
inputs = processor(text=possible_class_names, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1).squeeze()

for name, prob in zip(possible_class_names, probs):
    print(f"{name}: {prob:.4%}")
image  # displays the image when run in a notebook
```
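
Here `logits_per_image` holds the scaled cosine similarities between the image embedding and each candidate label embedding, so the softmax over them can be read as class probabilities.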

## Training info

### Training hyperparameters

The following hyperparameters were used during training:
 
### Training results

| Training Loss | Epoch | Step  | Validation Loss |
|:-------------:|:-----:|:-----:|:---------------:|
| 0.0974        | 4.13  | 22500 | 0.3388          |

## Framework versions

- Transformers 4.31.0.dev0
- Pytorch 2.0.1+cu117