added example applications
README.md
@@ -35,6 +35,100 @@ It achieves the following results on the evaluation set:
Here is a heatmap of the similarity scores between the first 30 images in the test split of the ROCO dataset and their captions:

![heatmap](https://imgur.com/fPFM694.png)

## Applications

### Image Retrieval

This model can be used for image retrieval, as demonstrated below:

#### Save Image Embeddings

```
import os
import pickle

import numpy as np
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# load the model and its processor
model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# TODO: set this to the directory that holds your images
images_path = "/path/to/images/"
images = [os.path.join(images_path, i) for i in os.listdir(images_path) if i.endswith(".jpg")]

# generate an embedding for each image in your dataset
image_embeds = []
for img in images:
    with torch.no_grad():
        inputs = processor(text=None, images=Image.open(img), return_tensors="pt", padding=True)
        outputs = model.get_image_features(**inputs)[0].numpy()
    image_embeds.append(outputs)

# save the image embeddings in a pickle file
with open("embeddings.pkl", "wb") as f:
    pickle.dump(np.array(image_embeds), f)
```
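
The query step maps result indices back to file names, so the saved embeddings must stay in the same order as the image list. One way to avoid relying on directory listing order is to store the paths next to the embeddings. The following is a minimal sketch of that idea using only the standard library; the helper names `save_index` and `load_index`, the file name `index.pkl`, and the toy vectors are illustrative assumptions and not part of the original example.

```python
import os
import pickle
import tempfile


def save_index(paths, embeddings, file_name):
    """Store image paths and their embeddings together (hypothetical helper)."""
    if len(paths) != len(embeddings):
        raise ValueError("need exactly one embedding per image path")
    with open(file_name, "wb") as f:
        pickle.dump({"paths": list(paths), "embeddings": list(embeddings)}, f)


def load_index(file_name):
    """Return (paths, embeddings) in the order they were saved."""
    with open(file_name, "rb") as f:
        index = pickle.load(f)
    return index["paths"], index["embeddings"]


# Toy stand-ins for real image paths and model embeddings (assumed values).
toy_paths = ["a.jpg", "b.jpg"]
toy_embeddings = [[0.1, 0.9], [0.8, 0.2]]

with tempfile.TemporaryDirectory() as tmp_dir:
    index_file = os.path.join(tmp_dir, "index.pkl")
    save_index(toy_paths, toy_embeddings, index_file)
    loaded_paths, loaded_embeddings = load_index(index_file)

print(loaded_paths)
```

With an index like this, position `i` in the embeddings always refers to `paths[i]`, so the query step would not need to list the image directory again.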

#### Query for Images

```
import os
import pickle

import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# load the model and its processor
model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# the text query to search for in the embeddings
query = "Chest X-Ray photos"

# embed the query
inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
with torch.no_grad():
    query_embedding = model.get_text_features(**inputs)[0].numpy()

# load the image embeddings
with open("embeddings.pkl", "rb") as f:
    image_embeds = pickle.load(f)

# find the indices of the k most similar images
def find_k_similar_images(query_embedding, image_embeds, k=2):
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeds)
    closest_indices = np.argsort(similarities[0])[::-1][:k]
    return closest_indices

k = 2  # number of images to retrieve
similar_image_indices = find_k_similar_images(query_embedding, image_embeds, k=k)

# TODO: set this to the same directory used when the embeddings were saved;
# the listing order must match the order the embeddings were generated in
images_path = "/path/to/images/"
images = [os.path.join(images_path, i) for i in os.listdir(images_path) if i.endswith(".jpg")]

# get the paths of the retrieved images and open the best match
similar_image_names = [images[index] for index in similar_image_indices]
Image.open(similar_image_names[0])
```
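
To make the ranking step concrete, here is a dependency-free sketch of the same idea: score every stored embedding against the query by cosine similarity and keep the `k` best indices. The function names `cosine` and `top_k_indices` and the two-dimensional toy vectors are assumptions made for illustration; real embeddings come from the model and have many more dimensions.

```python
import heapq
import math


def cosine(u, v):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def top_k_indices(query, embeddings, k=2):
    """Indices of the k embeddings most similar to the query, best first."""
    scores = [cosine(query, e) for e in embeddings]
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


# Toy data (assumed values): the second vector points almost the same way as the query.
toy_query = [1.0, 0.0]
toy_embeddings = [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5]]

print(top_k_indices(toy_query, toy_embeddings, k=2))
```

`heapq.nlargest` avoids sorting the whole score list, which can matter when `k` is small and the collection is large.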

### Zero-Shot Image Classification

This model can also be used for zero-shot image classification, as shown below:

```
import requests
import torch
from PIL import Image
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# load the model and its processor
model = VisionTextDualEncoderModel.from_pretrained("kaveh/rclip")
processor = VisionTextDualEncoderProcessor.from_pretrained("kaveh/rclip")

# download a sample image and list the candidate class names
url = "https://huggingface.co/spaces/kaveh/radiology-image-retrieval/resolve/main/images/ROCO_09402.jpg"
image = Image.open(requests.get(url, stream=True).raw)
possible_class_names = ["Chest X-Ray", "Brain MRI", "Abdominal CT Scan", "Ultrasound", "OPG"]

# score the image against every class name and convert the scores to probabilities
inputs = processor(text=possible_class_names, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1).squeeze()

# print each class name with its probability
for name, prob in zip(possible_class_names, probs):
    print(f"{name}: {prob.item():.4%}")

# display the image (when run in a notebook)
image
```
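
The probabilities above come from applying a softmax to the image-to-text logits, one logit per candidate class name. As a rough illustration of that step only, the sketch below applies a numerically stable softmax to made-up logits; the `softmax` helper and all values are assumptions and do not reproduce the model's actual outputs.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for three candidate class names (assumed values).
toy_class_names = ["Chest X-Ray", "Brain MRI", "Ultrasound"]
toy_logits = [4.0, 1.0, 0.5]

toy_probs = softmax(toy_logits)
best = max(range(len(toy_probs)), key=toy_probs.__getitem__)

for name, prob in zip(toy_class_names, toy_probs):
    print(f"{name}: {prob:.4%}")
print("predicted:", toy_class_names[best])
```

Because the softmax normalizes over the candidate list, the probabilities depend on which class names are supplied: adding or removing a candidate changes every value.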

## Training info

### Training hyperparameters

The following hyperparameters were used during training:

@@ -98,7 +192,7 @@ The following hyperparameters were used during training:

| 0.0974 | 4.13 | 22500 | 0.3388 |

## Framework versions

- Transformers 4.31.0.dev0
- Pytorch 2.0.1+cu117