CallmeKaito committed
Commit
ed002eb
1 Parent(s): 493c54a

Upload 3 files

Files changed (3)
  1. models/CLIP.py +141 -0
  2. models/LLaVa.py +140 -0
  3. models/SAM.py +205 -0
models/CLIP.py ADDED
@@ -0,0 +1,141 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[1]:
+
+
+ get_ipython().system('pip install ftfy regex tqdm')
+ get_ipython().system('pip install git+https://github.com/openai/CLIP.git')
+ get_ipython().system('pip install sentencepiece-0.1.98-cp311-cp311-win_amd64.whl')
+
+
+ # In[5]:
+
+
+ # prompt: install transformers
+
+ get_ipython().system('pip install transformers')
+
+
+ # In[6]:
+
+
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+
+ # Image-captioning model used to propose candidate descriptions.
+ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+
+ # ## Import the necessary libraries and load the CLIP model:
+
+ # In[7]:
+
+
+ from PIL import Image
+ import clip
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/32", device=device)
+
+
+ # ## Define a function to generate product descriptions:
+
+ # In[8]:
+
+
+ image = Image.open("data/download.jpeg")
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+ output_ids = model.generate(pixel_values, max_length=50, num_beams=4, early_stopping=True)
+ captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+
+ # In[9]:
+
+
+ # Keep the PIL image around; use a separate variable for the preprocessed tensor.
+ image_input = preprocess(image).unsqueeze(0).to(device)
+ with torch.no_grad():
+     image_features = clip_model.encode_image(image_input)
+
+ text_inputs = torch.cat([clip.tokenize(caption) for caption in captions]).to(device)
+ with torch.no_grad():
+     text_features = clip_model.encode_text(text_inputs)
+
+ # Pick the candidate caption CLIP considers closest to the image.
+ similarity_scores = image_features @ text_features.T
+ best_caption_idx = similarity_scores.argmax().item()
+ product_description = captions[best_caption_idx]
+ print(product_description)
+
+
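+ # In[ ]:
+
+
+ # A minimal sketch of the reusable helper promised by the "Define a function to generate
+ # product descriptions" heading above: it proposes several candidate captions with the
+ # vit-gpt2 captioner and keeps the one CLIP scores highest against the image. The function
+ # name and `num_candidates` are illustrative, not part of the original notebook; it assumes
+ # `feature_extractor`, `tokenizer`, `model`, `clip_model`, `preprocess` and `device` from the
+ # cells above.
+
+ def generate_product_description(image_path, num_candidates=4):
+     pil_image = Image.open(image_path).convert("RGB")
+
+     # Candidate captions via beam search (one caption per returned sequence).
+     pixel_values = feature_extractor(images=pil_image, return_tensors="pt").pixel_values
+     output_ids = model.generate(
+         pixel_values,
+         max_length=50,
+         num_beams=max(4, num_candidates),
+         num_return_sequences=num_candidates,
+         early_stopping=True,
+     )
+     candidates = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+     # Re-rank with CLIP: cosine similarity between the image and each candidate caption.
+     image_tensor = preprocess(pil_image).unsqueeze(0).to(device)
+     text_tokens = clip.tokenize(candidates).to(device)
+     with torch.no_grad():
+         img_feat = clip_model.encode_image(image_tensor)
+         txt_feat = clip_model.encode_text(text_tokens)
+     img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
+     txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
+     best = (img_feat @ txt_feat.T).argmax().item()
+     return candidates[best]
+
+ # Example call (path is a placeholder):
+ # print(generate_product_description("data/download.jpeg"))
+
+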
+ # # Using SigLIP
+
+ # In[11]:
+
+
+ get_ipython().system('pip install sentencepiece')
+ get_ipython().system('pip install protobuf')
+
+
+ # In[12]:
+
+
+ from transformers import AutoProcessor, AutoModel, VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+ import torch
+ from PIL import Image
+
+ # Load SigLIP under its own names so it is not overwritten by the captioning model below.
+ siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ image = Image.open("data/avito4.jpeg")
+ inputs = siglip_processor(images=image, return_tensors="pt")  # image-only preprocessing; SigLIP scoring itself is sketched after this cell
+
+ # Candidate captions from the vit-gpt2 captioner.
+ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+ output_ids = model.generate(pixel_values, max_length=100, num_beams=5, early_stopping=True)
+ captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+ # Re-rank the candidates with CLIP, as in the previous section.
+ image_input = preprocess(image).unsqueeze(0).to(device)
+ with torch.no_grad():
+     image_features = clip_model.encode_image(image_input)
+
+ text_inputs = torch.cat([clip.tokenize(caption) for caption in captions]).to(device)
+ with torch.no_grad():
+     text_features = clip_model.encode_text(text_inputs)
+
+ similarity_scores = image_features @ text_features.T
+ best_caption_idx = similarity_scores.argmax().item()
+ product_description = captions[best_caption_idx]
+ print(product_description)
+
+ # Captions obtained for the sample images:
+ # a vase sitting on a shelf in a store => thuya
+ # a wooden bench sitting on top of a wooden floor => avito
+ ## two old fashioned vases sitting next to each other => avito2
+ ## three wooden vases sitting on top of a wooden floor => avito3
+ # an old fashioned clock sitting on top of a table => avito4
+
+
+ # In[ ]:
+
+
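+ # In[ ]:
+
+
+ # Sketch of scoring the candidate captions with SigLIP itself (the model loaded above),
+ # rather than re-using CLIP. Assumes `siglip_model`, `siglip_processor`, `image` (the PIL
+ # image) and `captions` from the cells above; SigLIP applies a sigmoid to its logits instead
+ # of a softmax, so each caption gets an independent match probability.
+
+ siglip_inputs = siglip_processor(text=captions, images=image, padding="max_length", return_tensors="pt")
+ with torch.no_grad():
+     siglip_outputs = siglip_model(**siglip_inputs)
+
+ logits_per_image = siglip_outputs.logits_per_image       # shape: (1, number of captions)
+ probs = torch.sigmoid(logits_per_image)
+ print(captions[logits_per_image.argmax().item()], probs.max().item())
+
+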
+ # # Implementing LLaVa
+
+ # https://colab.research.google.com/drive/1veefV17NcD1S4ou4nF8ABkfm8-TgU0Dr#scrollTo=XN2vJCPZk1UY
+
+ # In[ ]:
+
+
models/LLaVa.py ADDED
@@ -0,0 +1,140 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # # Set-up environment
+
+ # In[2]:
+
+
+ get_ipython().system('pip install --upgrade -q accelerate bitsandbytes')
+
+
+ # In[ ]:
+
+
+ # Install the branch with the LLaVa improvements. Note that `!cd` does not persist across
+ # shell calls in a notebook, so we install straight from the cloned folder instead.
+ get_ipython().system('rm -rf transformers')
+ get_ipython().system('git clone -b llava_improvements https://github.com/NielsRogge/transformers.git')
+ get_ipython().system('pip install -q ./transformers')
+
+
+ # In[ ]:
+
+
+ get_ipython().system('pip install git+https://github.com/huggingface/transformers.git')
+
+
+ # ## Load model and processor
+
+ # In[ ]:
+
+
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ from transformers import BitsAndBytesConfig
+ import torch
+
+ # 4-bit quantization so the 7B model fits on a single consumer GPU.
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16
+ )
+
+ model_id = "llava-hf/llava-1.5-7b-hf"
+
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = LlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
+
+
+ # ## Prepare image and text for the model
+
+ # In[ ]:
+
+
+ import requests
+ from PIL import Image
+
+ image1 = Image.open('data/clock.jpeg')
+ display(image1)
+
+
+ # In the prompt, you can refer to images using the special \<image> token. To indicate which text comes from a human vs. the model, one uses USER and ASSISTANT respectively. The format looks as follows:
+ #
+ # ```bash
+ # USER: <image>\n<prompt>\nASSISTANT:
+ # ```
+
+ # In other words, you always need to end your prompt with `ASSISTANT:`. The processor supports batched generation (i.e. generating on several prompts at once), although only a single prompt is active below.
+
+ # In[ ]:
+
+
+ caption = 'an old fashioned clock sitting on top of a table'
+
+ user_input = "This is an intricately crafted old-fashioned clock created by a skilled Moroccan artisan back in 1988 in Chefchaouen... it reminds me of my mother."
+
+ prompts = [
+     f"USER: <image>\nBased on the caption '{caption}' and the following user input: '{user_input}', generate a detailed product name and description for this Moroccan artisanal item; the description should be minimal yet capture the essence of the product and convince people to buy it or express their interest in it.\nASSISTANT:"
+     # f"""
+     # USER: <image>\nBased on the image caption '{caption}' and the following background information: '{user_input}', generate an attention-grabbing yet concise product name and description for this authentic Moroccan artisanal item. The description should:
+     # Highlight the key features and unique selling points that make this product exceptional and desirable.
+     # Convey the cultural significance, craftsmanship, and rich heritage behind the item's creation.
+     # Use evocative language that resonates with potential buyers and piques their interest in owning this one-of-a-kind piece.
+     # Be concise, direct, and persuasive, leaving the reader eager to learn more or acquire the product.
+
+     # Your response should follow this format:
+     # Product Name: [Compelling and relevant product name]
+     # Product Description: [Concise yet captivating description addressing the points above]
+     # ASSISTANT:"""
+ ]
+
+ inputs = processor(prompts, images=[image1], padding=True, return_tensors="pt").to("cuda")
+ for k, v in inputs.items():
+     print(k, v.shape)
+
+
+ # ## Autoregressively generate completion
+ #
+ # Finally, we simply let the model predict the next tokens given the images + prompt. Of course, one can adjust all the [generation parameters](https://huggingface.co/docs/transformers/v4.35.2/en/main_classes/text_generation#transformers.GenerationMixin.generate). By default, greedy decoding is used.
+
+ # In[ ]:
+
+
+ output = model.generate(**inputs, max_new_tokens=200)
+ generated_text = processor.batch_decode(output, skip_special_tokens=True)
+ for text in generated_text:
+     print(text.split("ASSISTANT:")[-1])
+
+
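+ # In[ ]:
+
+
+ # A small convenience wrapper around the prompt format and `generate()` call above, so a
+ # listing can be produced for any (image, caption, user_input) triple. The function name is
+ # illustrative; it assumes the `processor` and `model` loaded earlier.
+
+ def generate_listing(image, caption, user_input, max_new_tokens=200):
+     prompt = (
+         f"USER: <image>\nBased on the caption '{caption}' and the following user input: "
+         f"'{user_input}', generate a detailed product name and description for this Moroccan "
+         f"artisanal item; keep it minimal yet convincing.\nASSISTANT:"
+     )
+     inputs = processor(prompt, images=[image], padding=True, return_tensors="pt").to("cuda")
+     output = model.generate(**inputs, max_new_tokens=max_new_tokens)
+     text = processor.batch_decode(output, skip_special_tokens=True)[0]
+     return text.split("ASSISTANT:")[-1].strip()
+
+ # Example call:
+ # print(generate_listing(image1, caption, user_input))
+
+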
+ # ## Pipeline API
+ #
+ # Alternatively, you can leverage the [pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines) API, which abstracts all of the logic above away for the user. We also provide the quantization config to make sure we leverage 4-bit inference.
+
+ # In[ ]:
+
+
+ from transformers import pipeline
+
+ pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
+
+
+ # In[ ]:
+
+
+ max_new_tokens = 200
+ prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
+
+ outputs = pipe(image1, prompt=prompt, generate_kwargs={"max_new_tokens": max_new_tokens})
+
+
+ # In[ ]:
+
+
+ print(outputs[0]["generated_text"])
+
+
+ # In[ ]:
+
+
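+ # In[ ]:
+
+
+ # Recent transformers releases also ship a chat template for the llava-hf checkpoints, which
+ # builds the "USER: <image>\n...\nASSISTANT:" string for you. A sketch, assuming a recent
+ # transformers version; the conversation content below is illustrative.
+
+ conversation = [
+     {
+         "role": "user",
+         "content": [
+             {"type": "image"},
+             {"type": "text", "text": f"Based on the caption '{caption}' and this user input: '{user_input}', write a short product name and description."},
+         ],
+     },
+ ]
+ chat_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+ chat_inputs = processor(chat_prompt, images=[image1], return_tensors="pt").to("cuda")
+ chat_output = model.generate(**chat_inputs, max_new_tokens=200)
+ print(processor.batch_decode(chat_output, skip_special_tokens=True)[0].split("ASSISTANT:")[-1])
+
+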
models/SAM.py ADDED
@@ -0,0 +1,205 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # # Utility functions
+
+ # In[ ]:
+
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ def show_mask(mask, ax, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+ def show_boxes_on_image(raw_image, boxes):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     for box in boxes:
+         show_box(box, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+ def show_points_on_image(raw_image, input_points, input_labels=None):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     input_points = np.array(input_points)
+     if input_labels is None:
+         labels = np.ones_like(input_points[:, 0])
+     else:
+         labels = np.array(input_labels)
+     show_points(input_points, labels, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+ def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     input_points = np.array(input_points)
+     if input_labels is None:
+         labels = np.ones_like(input_points[:, 0])
+     else:
+         labels = np.array(input_labels)
+     show_points(input_points, labels, plt.gca())
+     for box in boxes:
+         show_box(box, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels == 1]
+     neg_points = coords[labels == 0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+
+ def show_masks_on_image(raw_image, masks, scores):
+     if len(masks.shape) == 4:
+         masks = masks.squeeze()
+     if scores.shape[0] == 1:
+         scores = scores.squeeze()
+
+     nb_predictions = scores.shape[-1]
+     fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
+     axes = np.atleast_1d(axes)  # so indexing also works when there is a single prediction
+
+     for i, (mask, score) in enumerate(zip(masks, scores)):
+         mask = mask.cpu().detach()
+         axes[i].imshow(np.array(raw_image))
+         show_mask(mask, axes[i])
+         axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
+         axes[i].axis("off")
+     plt.show()
+
+
+ # # Model loading
+
+ # In[ ]:
+
+
+ import torch
+ from transformers import SamModel, SamProcessor
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+
+ # In[ ]:
+
+
+ from PIL import Image
+
+ img_path = "thuya.jpeg"
+ raw_image = Image.open(img_path)
+
+ plt.imshow(raw_image)
+
+
+ # ## Step 1: Retrieve the image embeddings
+
+ # In[ ]:
+
+
+ inputs = processor(raw_image, return_tensors="pt").to(device)
+ image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+
+
+ # In[ ]:
+
+
+ # A single foreground point prompt (x, y).
+ input_points = [[[200, 300]]]
+ show_points_on_image(raw_image, input_points[0])
+
+
+ # In[ ]:
+
+
+ inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
+ # pop the pixel_values as they are not needed: we reuse the precomputed image embeddings
+ inputs.pop("pixel_values", None)
+ inputs.update({"image_embeddings": image_embeddings})
+
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+ scores = outputs.iou_scores
+
+
+ # In[ ]:
+
+
+ show_masks_on_image(raw_image, masks[0], scores)
+
+
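+ # In[ ]:
+
+
+ # The utilities above can also draw boxes, but so far SAM is only prompted with a point.
+ # A sketch of prompting with a bounding box instead, reusing the cached image embeddings;
+ # the box coordinates below are placeholders, not values from the original notebook.
+
+ input_boxes = [[[100, 150, 500, 600]]]  # [x_min, y_min, x_max, y_max]
+ show_boxes_on_image(raw_image, input_boxes[0])
+
+ box_inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
+ box_inputs.pop("pixel_values", None)  # reuse the precomputed image embeddings
+ box_inputs.update({"image_embeddings": image_embeddings})
+
+ with torch.no_grad():
+     box_outputs = model(**box_inputs)
+
+ box_masks = processor.image_processor.post_process_masks(
+     box_outputs.pred_masks.cpu(),
+     box_inputs["original_sizes"].cpu(),
+     box_inputs["reshaped_input_sizes"].cpu(),
+ )
+ show_masks_on_image(raw_image, box_masks[0], box_outputs.iou_scores)
+
+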
+ # ## Export the masked images
+
+ # In[92]:
+
+
+ import cv2
+
+ if len(masks[0].shape) == 4:
+     masks[0] = masks[0].squeeze()
+ if scores.shape[0] == 1:
+     scores = scores.squeeze()
+
+ nb_predictions = scores.shape[-1]
+ fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
+ image = cv2.imread(img_path)  # BGR image used for the OpenCV exports
+
+ for i, (mask, score) in enumerate(zip(masks[0], scores)):
+     mask = mask.cpu().detach()
+     axes[i].imshow(np.array(raw_image))
+     # show_mask(mask, axes[i])
+
+     # Save the raw binary mask (one file per predicted mask).
+     mask_np = mask.numpy()
+     mask_image = (mask_np * 255).astype(np.uint8)  # convert to uint8 format
+     cv2.imwrite(f'mask_{i}.png', mask_image)
+
+     # Blend a colour overlay onto the original image.
+     color_mask = np.zeros_like(image)
+     color_mask[mask_np > 0.5] = [30, 144, 255]  # choose any colour you like
+     masked_image = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)
+     cv2.imwrite(f'masked_image_{i}.png', masked_image)
+
+     # Keep only the pixels inside the mask (black background elsewhere).
+     new_image = image * np.tile((mask_np > 0.5)[..., None], 3)
+     cv2.imwrite(f'masked_image2_{i}.png', new_image)
+
+
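+ # In[ ]:
+
+
+ # Sketch: export the segmented object as a PNG cutout with a transparent background, which is
+ # often more useful for product photos than a colour overlay. Assumes `raw_image`, `masks`
+ # and `scores` from the cells above; taking the highest-scoring mask is an arbitrary choice.
+
+ best_idx = scores.squeeze().argmax().item()
+ best_mask = masks[0].squeeze()[best_idx].cpu().numpy()   # boolean (H, W) mask
+
+ rgba = np.array(raw_image.convert("RGBA"))
+ rgba[..., 3] = (best_mask * 255).astype(np.uint8)        # alpha channel taken from the mask
+ Image.fromarray(rgba).save("cutout.png")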