Spaces: Build error
CallmeKaito committed
Commit ed002eb
Parent(s): 493c54a
Upload 3 files
Browse files:
- models/CLIP.py +141 -0
- models/LLaVa.py +140 -0
- models/SAM.py +205 -0
models/CLIP.py
ADDED
@@ -0,0 +1,141 @@
#!/usr/bin/env python
# coding: utf-8

# In[1]:


get_ipython().system('pip install ftfy regex tqdm')
get_ipython().system('pip install git+https://github.com/openai/CLIP.git')
get_ipython().system('pip install sentencepiece-0.1.98-cp311-cp311-win_amd64.whl')


# In[5]:


# prompt: install transformers

get_ipython().system('pip install transformers')


# In[6]:


from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

# Image-captioning model: ViT encoder + GPT-2 decoder
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")


# ## Import the necessary libraries and load the CLIP model:

# In[7]:


from PIL import Image
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)


# ## Define a function to generate product descriptions:

# In[8]:


# Generate candidate captions for the product image
image = Image.open("data/download.jpeg")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
output_ids = model.generate(pixel_values, max_length=50, num_beams=4, early_stopping=True)
captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)


# In[9]:


# Rank the candidate captions by CLIP image-text similarity and keep the best one
image = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = clip_model.encode_image(image)

text_inputs = torch.cat([clip.tokenize(caption) for caption in captions]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)

similarity_scores = image_features @ text_features.T
best_caption_idx = similarity_scores.argmax().item()
product_description = captions[best_caption_idx]
print(product_description)

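# The heading above promises a reusable function, but the cells above run the steps inline.
# Below is a minimal sketch of how the captioning + CLIP-ranking steps could be wrapped into
# one helper. It assumes the `feature_extractor`, `tokenizer`, `model`, `clip_model` and
# `preprocess` objects loaded above; the function name and num_captions parameter are
# illustrative, not part of the original notebook.
def generate_product_description(image_path, num_captions=4):
    """Caption an image and return the candidate that CLIP scores highest against it."""
    img = Image.open(image_path)
    pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values
    # Return several beam candidates so the CLIP ranking has something to choose between
    output_ids = model.generate(pixel_values, max_length=50, num_beams=num_captions,
                                num_return_sequences=num_captions, early_stopping=True)
    candidates = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    with torch.no_grad():
        img_feat = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))
        txt_feat = clip_model.encode_text(torch.cat([clip.tokenize(c) for c in candidates]).to(device))
    return candidates[(img_feat @ txt_feat.T).argmax().item()]
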
# # Using SigLIP

# In[11]:


get_ipython().system('pip install sentencepiece')
get_ipython().system('pip install protobuf')


# In[12]:


from transformers import AutoProcessor, AutoModel, VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

# Load SigLIP; note that the scoring below still reuses the CLIP encoders loaded earlier,
# so these two objects are only prepared here and not used yet.
siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

image = Image.open("data/avito4.jpeg")
inputs = siglip_processor(images=image, return_tensors="pt")

feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
output_ids = model.generate(pixel_values, max_length=100, num_beams=5, early_stopping=True)
captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# Rank the captions with CLIP, exactly as above
image = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = clip_model.encode_image(image)

text_inputs = torch.cat([clip.tokenize(caption) for caption in captions]).to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)

similarity_scores = image_features @ text_features.T
best_caption_idx = similarity_scores.argmax().item()
product_description = captions[best_caption_idx]
print(product_description)

# Captions obtained for the test images:
# a vase sitting on a shelf in a store                => thuya
# a wooden bench sitting on top of a wooden floor     => avito
# two old fashioned vases sitting next to each other  => avito2
# three wooden vases sitting on top of a wooden floor => avito3
# an old fashioned clock sitting on top of a table    => avito4

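# The SigLIP model loaded above is never actually used for scoring. A minimal sketch of how
# the same caption ranking could be done with SigLIP itself, assuming the `siglip_model` and
# `siglip_processor` objects and the `captions` list from above (SigLIP text inputs are
# expected with padding="max_length"):
pil_image = Image.open("data/avito4.jpeg")
siglip_inputs = siglip_processor(text=captions, images=pil_image, padding="max_length", return_tensors="pt")
with torch.no_grad():
    siglip_outputs = siglip_model(**siglip_inputs)
# logits_per_image has shape (num_images, num_captions); higher means a better match
best_idx = siglip_outputs.logits_per_image.argmax(dim=-1).item()
print(captions[best_idx])
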
# In[ ]:




# # Implementing LLaVa

# https://colab.research.google.com/drive/1veefV17NcD1S4ou4nF8ABkfm8-TgU0Dr#scrollTo=XN2vJCPZk1UY

# In[ ]:


models/LLaVa.py
ADDED
@@ -0,0 +1,140 @@
#!/usr/bin/env python
# coding: utf-8

# # Set-up environment

# In[2]:


get_ipython().system('pip install --upgrade -q accelerate bitsandbytes')


# In[ ]:


# Note: each get_ipython().system() call runs in its own shell, so the `cd` below has no
# lasting effect; the install still works because the path is given relative to the notebook.
get_ipython().system('rm -r transformers')
get_ipython().system('git clone -b llava_improvements https://github.com/NielsRogge/transformers.git')
get_ipython().system('cd transformers')
get_ipython().system('pip install -q ./transformers')


# In[ ]:


get_ipython().system('pip install git+https://github.com/huggingface/transformers.git')


# ## Load model and processor

# In[ ]:


from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
import torch

# Quantize the 7B model to 4 bits so it fits on a single consumer GPU
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

model_id = "llava-hf/llava-1.5-7b-hf"

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")


# ## Prepare image and text for the model

# In[ ]:


import requests
from PIL import Image

image1 = Image.open('data/clock.jpeg')
display(image1)


# In the prompt, you can refer to images using the special \<image> token. To indicate which text comes from a human vs. the model, one uses USER and ASSISTANT respectively. The format looks as follows:
#
# ```bash
# USER: <image>\n<prompt>\nASSISTANT:
# ```

# In other words, you always need to end your prompt with `ASSISTANT:`. Here we will perform batched generation (i.e. generating on several prompts); a sketch of batching two prompts is shown after the next cell.

# In[ ]:


caption = 'an old fashioned clock sitting on top of a table'

user_input = "This is an intricately crafted old-fashioned clock created by a skilled Moroccan artisan back in 1988 in Chefchaouen.. it reminds me of my mother."

prompts = [
    f"USER: <image>\nBased on the caption '{caption}' and the following user input: '{user_input}', generate a detailed product name and description for this Moroccan artisanal item; the description should be minimal yet give the essence of the product and convince people to buy or express their interest in it.\nASSISTANT:"
    # f"""
    # USER: <image>\nBased on the image caption '{caption}' and the following background information: '{user_input}', generate an attention-grabbing yet concise product name and description for this authentic Moroccan artisanal item. The description should:
    # Highlight the key features and unique selling points that make this product exceptional and desirable.
    # Convey the cultural significance, craftsmanship, and rich heritage behind the item's creation.
    # Use evocative language that resonates with potential buyers and piques their interest in owning this one-of-a-kind piece.
    # Be concise, direct, and persuasive, leaving the reader eager to learn more or acquire the product.

    # Your response should follow this format:
    # Product Name: [Compelling and relevant product name]
    # Product Description: [Concise yet captivating description addressing the points above]
    # ASSISTANT:"""
]

inputs = processor(prompts, images=[image1], padding=True, return_tensors="pt").to("cuda")
for k, v in inputs.items():
    print(k, v.shape)

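# The cell above batches only a single prompt. Below is a minimal sketch of true batched
# generation on several prompts at once; the second image path 'data/vase.jpeg' and the
# prompt texts are illustrative, not part of the original notebook. padding=True aligns the
# tokenized prompts to a common length.
image2 = Image.open('data/vase.jpeg')
batched_prompts = [
    "USER: <image>\nDescribe this Moroccan artisanal item in one sentence.\nASSISTANT:",
    "USER: <image>\nSuggest a short product name for this item.\nASSISTANT:",
]
batched_inputs = processor(batched_prompts, images=[image1, image2], padding=True, return_tensors="pt").to("cuda")
batched_output = model.generate(**batched_inputs, max_new_tokens=100)
for text in processor.batch_decode(batched_output, skip_special_tokens=True):
    print(text.split("ASSISTANT:")[-1].strip())
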
# ## Autoregressively generate completion
#
# Finally, we simply let the model predict the next tokens given the images + prompt. Of course, one can adjust all the [generation parameters](https://huggingface.co/docs/transformers/v4.35.2/en/main_classes/text_generation#transformers.GenerationMixin.generate). By default, greedy decoding is used; a sketch with sampling enabled follows this cell.

# In[ ]:


output = model.generate(**inputs, max_new_tokens=200)
generated_text = processor.batch_decode(output, skip_special_tokens=True)
for text in generated_text:
    print(text.split("ASSISTANT:")[-1])

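# A minimal sketch of overriding the default greedy decoding with sampling; the parameter
# values below are illustrative, not tuned:
sampled_output = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=True,    # sample instead of taking the argmax at each step
    temperature=0.7,   # flatten/sharpen the next-token distribution
    top_p=0.9,         # nucleus sampling
)
for text in processor.batch_decode(sampled_output, skip_special_tokens=True):
    print(text.split("ASSISTANT:")[-1].strip())
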
# ## Pipeline API
#
# Alternatively, you can leverage the [pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines) API, which abstracts all of the logic above away from the user. We also provide the quantization config to make sure we leverage 4-bit inference.

# In[ ]:


from transformers import pipeline

pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})


# In[ ]:


max_new_tokens = 200
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"

outputs = pipe(image1, prompt=prompt, generate_kwargs={"max_new_tokens": max_new_tokens})


# In[ ]:


print(outputs[0]["generated_text"])


# In[ ]:


models/SAM.py
ADDED
@@ -0,0 +1,205 @@
#!/usr/bin/env python
# coding: utf-8

# # Utility functions

# In[ ]:


import numpy as np
import matplotlib.pyplot as plt

def show_mask(mask, ax, random_color=False):
    """Overlay a single segmentation mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_boxes_on_image(raw_image, boxes):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_masks_on_image(raw_image, masks, scores):
    if len(masks.shape) == 4:
        masks = masks.squeeze()
    if scores.shape[0] == 1:
        scores = scores.squeeze()

    nb_predictions = scores.shape[-1]
    fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))

    for i, (mask, score) in enumerate(zip(masks, scores)):
        mask = mask.cpu().detach()
        axes[i].imshow(np.array(raw_image))
        show_mask(mask, axes[i])
        axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        axes[i].axis("off")
    plt.show()


# # Model loading

# In[ ]:


import torch
from transformers import SamModel, SamProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


# In[ ]:


from PIL import Image
import requests

img_path = "thuya.jpeg"
raw_image = Image.open(img_path)

plt.imshow(raw_image)


# ## Step 1: Retrieve the image embeddings

# In[ ]:


# Compute the image embeddings once; they can be reused for any number of prompts
inputs = processor(raw_image, return_tensors="pt").to(device)
image_embeddings = model.get_image_embeddings(inputs["pixel_values"])


# In[ ]:


# A single point prompt (x, y) placed on the object of interest
input_points = [[[200, 300]]]
show_points_on_image(raw_image, input_points[0])


# In[ ]:


inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
# pop the pixel_values as they are not needed: we reuse the precomputed image embeddings
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})

with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores


# In[ ]:


show_masks_on_image(raw_image, masks[0], scores)

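# The helpers above also support box prompts, which the notebook never exercises. Below is a
# minimal sketch of prompting SAM with a bounding box instead of a point, reusing the
# precomputed image embeddings; the box coordinates are illustrative, not part of the
# original notebook.
input_boxes = [[[100, 150, 500, 600]]]  # [x0, y0, x1, y1] in pixel coordinates
show_boxes_on_image(raw_image, input_boxes[0])

box_inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
box_inputs.pop("pixel_values", None)
box_inputs.update({"image_embeddings": image_embeddings})

with torch.no_grad():
    box_outputs = model(**box_inputs)

box_masks = processor.image_processor.post_process_masks(
    box_outputs.pred_masks.cpu(), box_inputs["original_sizes"].cpu(), box_inputs["reshaped_input_sizes"].cpu()
)
show_masks_on_image(raw_image, box_masks[0], box_outputs.iou_scores)
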
# ## Export the masked images

# In[92]:


import cv2

if len(masks[0].shape) == 4:
    masks[0] = masks[0].squeeze()
if scores.shape[0] == 1:
    scores = scores.squeeze()

nb_predictions = scores.shape[-1]
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
for i, (mask, score) in enumerate(zip(masks[0], scores)):
    mask = mask.cpu().detach()
    axes[i].imshow(np.array(raw_image))
    # show_mask(mask, axes[i])

    mask_np = mask.numpy()

    # Save the raw binary mask
    mask_image = (mask_np * 255).astype(np.uint8)  # convert the boolean mask to uint8
    cv2.imwrite('mask.png', mask_image)

    image = cv2.imread('thuya.jpeg')

    # Blend a solid colour over the masked region
    color_mask = np.zeros_like(image)
    color_mask[mask_np > 0.5] = [30, 144, 255]  # choose any colour you like
    masked_image = cv2.addWeighted(image, 0.6, color_mask, 0.4, 0)

    # Keep only the pixels inside the mask (everything else goes black)
    new_image = (image * np.tile(mask_np[..., None], 3)).astype(np.uint8)
    cv2.imwrite('masked_image2.png', new_image)


# In[85]:

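# A simpler export path, as a sketch: use the highest-scoring mask as an alpha channel so
# the background becomes transparent instead of black. PIL only, no OpenCV needed; the
# output file name is illustrative.
best_mask = masks[0][scores.argmax().item()].cpu().numpy().astype(np.uint8) * 255
rgba = raw_image.convert("RGBA")
rgba.putalpha(Image.fromarray(best_mask))
rgba.save("masked_image_transparent.png")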