Spaces:
Running
Running
fix encoder loading
Browse files- localization.py +17 -38
- utils.py +8 -6
localization.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from text2image import get_model, get_tokenizer, get_image_transform
|
3 |
from utils import text_encoder
|
4 |
-
from
|
5 |
from PIL import Image
|
6 |
from jax import numpy as jnp
|
7 |
import pandas as pd
|
@@ -13,7 +13,16 @@ import jax
|
|
13 |
import gc
|
14 |
|
15 |
|
16 |
-
preprocess =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
def resize_longer(image, longer_size=224):
|
@@ -89,18 +98,16 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
|
|
89 |
|
90 |
|
91 |
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
92 |
-
|
93 |
model = get_model()
|
94 |
image_size = model.config.vision_config.image_size
|
95 |
|
96 |
images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
|
97 |
input_image = images[0].copy()
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
text_embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
|
103 |
-
text_embedding = text_embedding / jnp.linalg.norm(text_embedding, axis=-1, keepdims=True)
|
104 |
|
105 |
vertical_scores = jnp.zeros((masks[0].shape[1], 512))
|
106 |
vertical_masks = jnp.zeros((masks[0].shape[1], 1))
|
@@ -131,39 +138,11 @@ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
|
131 |
embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
|
132 |
full_embs = jnp.minimum(embs_1, embs_2)
|
133 |
mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
|
134 |
-
|
135 |
-
print(full_embs.shape)
|
136 |
-
|
137 |
-
#full_embs = full_embs / jnp.linalg.norm(full_embs, axis=-1, keepdims=True)
|
138 |
full_embs = (full_embs / mask_sum)
|
139 |
|
140 |
orig_shape = full_embs.shape
|
141 |
-
sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)),
|
142 |
-
|
143 |
-
#sims = jax.nn.relu(sims)
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
# mean_vertical_scores = vertical_scores / vertical_masks
|
151 |
-
# mean_horizontal_scores = horizontal_scores / horizontal_masks
|
152 |
-
|
153 |
-
# print(mean_vertical_score)
|
154 |
-
# print(mean_horizontal_score)
|
155 |
-
|
156 |
-
# score = jnp.matmul(mean_vertical_scores, mean_horizontal_scores.T)
|
157 |
-
|
158 |
-
#mask = jnp.matmul(vertical_masks, horizontal_scores.T)
|
159 |
-
#score = score / mask
|
160 |
-
|
161 |
-
score = sims # jnp.expand_dims(score.T, axis=-1)
|
162 |
-
#score = jax.nn.relu(score) / jnp.max(jnp.abs(score))
|
163 |
-
|
164 |
-
#score = jax.nn.relu(score - sims[0])
|
165 |
-
|
166 |
-
# score = jnp.square(score)
|
167 |
|
168 |
for i in range(iterations):
|
169 |
score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
|
|
|
1 |
import streamlit as st
|
2 |
from text2image import get_model, get_tokenizer, get_image_transform
|
3 |
from utils import text_encoder
|
4 |
+
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
from jax import numpy as jnp
|
7 |
import pandas as pd
|
|
|
13 |
import gc
|
14 |
|
15 |
|
16 |
+
preprocess = transforms.Compose(
|
17 |
+
[
|
18 |
+
transforms.ToTensor(),
|
19 |
+
transforms.Resize(224),
|
20 |
+
transforms.Normalize(
|
21 |
+
(0.48145466, 0.4578275, 0.40821073),
|
22 |
+
(0.26862954, 0.26130258, 0.27577711)
|
23 |
+
),
|
24 |
+
]
|
25 |
+
)
|
26 |
|
27 |
|
28 |
def resize_longer(image, longer_size=224):
|
|
|
98 |
|
99 |
|
100 |
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
101 |
+
tokenizer = get_tokenizer()
|
102 |
model = get_model()
|
103 |
image_size = model.config.vision_config.image_size
|
104 |
|
105 |
images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
|
106 |
input_image = images[0].copy()
|
107 |
|
108 |
+
images = np.stack([preprocess(pad_to_square(image)) for image in images], axis=0)
|
109 |
+
image_embeddings, embedding_norms = image_encoder(images, model)
|
110 |
+
text_embeddings, _ = text_encoder(text, model, tokenizer)
|
|
|
|
|
111 |
|
112 |
vertical_scores = jnp.zeros((masks[0].shape[1], 512))
|
113 |
vertical_masks = jnp.zeros((masks[0].shape[1], 1))
|
|
|
138 |
embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
|
139 |
full_embs = jnp.minimum(embs_1, embs_2)
|
140 |
mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
|
|
|
|
|
|
|
|
|
141 |
full_embs = (full_embs / mask_sum)
|
142 |
|
143 |
orig_shape = full_embs.shape
|
144 |
+
sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embeddings.T)
|
145 |
+
score = jnp.reshape(sims, (*orig_shape[:2], 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
for i in range(iterations):
|
148 |
score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
|
utils.py
CHANGED
@@ -34,18 +34,20 @@ def text_encoder(text, model, tokenizer):
|
|
34 |
padding="max_length",
|
35 |
return_tensors="np",
|
36 |
)
|
37 |
-
embedding = model.get_text_features(
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
42 |
|
43 |
|
44 |
def image_encoder(image, model):
|
45 |
image = image.permute(1, 2, 0).numpy()
|
46 |
image = jnp.expand_dims(image, axis=0) # add batch size
|
47 |
features = model.get_image_features(image,)
|
48 |
-
|
|
|
49 |
return features
|
50 |
|
51 |
|
|
|
34 |
padding="max_length",
|
35 |
return_tensors="np",
|
36 |
)
|
37 |
+
embedding = model.get_text_features(
|
38 |
+
inputs["input_ids"],
|
39 |
+
inputs["attention_mask"])[0]
|
40 |
+
norms = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
|
41 |
+
embedding = embedding / norms
|
42 |
+
return jnp.expand_dims(embedding, axis=0), norms
|
43 |
|
44 |
|
45 |
def image_encoder(image, model):
|
46 |
image = image.permute(1, 2, 0).numpy()
|
47 |
image = jnp.expand_dims(image, axis=0) # add batch size
|
48 |
features = model.get_image_features(image,)
|
49 |
+
norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
|
50 |
+
features = features / norms
|
51 |
return features
|
52 |
|
53 |
|