Spaces: xichen98cn
Commit e7ae87a · 1 Parent(s): 5e46b82
Update frozenseg/frozenseg.py

frozenseg/frozenseg.py CHANGED (+59 -22)
@@ -16,6 +16,7 @@ from segment_anything.build_sam import sam_model_registry
 from .modeling.transformer_decoder.frozenseg_transformer_decoder import MaskPooling, get_classification_logits
 from segment_anything import sam_model_registry
 import pickle
+import os
 VILD_PROMPT = [
     "a photo of a {}.",
     "This is a photo of a {}",
@@ -33,6 +34,20 @@ VILD_PROMPT = [
     "There is a large {} in the scene.",
 ]
 
+def split_labels(x):
+    res = []
+    for x_ in x:
+        x_ = x_.replace(', ', ',')
+        x_ = x_.split(',') # there can be multiple synonyms for single class
+        res.append(x_)
+    return res
+
+def fill_all_templates_ensemble(x_=''):
+    res = []
+    for x in x_:
+        for template in VILD_PROMPT:
+            res.append(template.format(x))
+    return res, len(res) // len(VILD_PROMPT)
+
 
 @META_ARCH_REGISTRY.register()
 class FrozenSeg(nn.Module):
@@ -132,6 +147,14 @@ class FrozenSeg(nn.Module):
 
         _, self.train_num_templates, self.train_class_names = self.prepare_class_names_from_metadata(train_metadata, train_metadata)
         self.category_overlapping_mask, self.test_num_templates, self.test_class_names = self.prepare_class_names_from_metadata(test_metadata, train_metadata)
+
+        self.demo_all_text_embedding_cache = {}
+        # This consists of COCO, ADE20K, LVIS
+        if os.path.exists("demo_all_text_embedding_cache.pth"):
+            # key: str of class name, value: tensor in shape of C
+            self.demo_all_text_embedding_cache = torch.load("demo_all_text_embedding_cache.pth", map_location=self.device)
+            self.demo_all_text_embedding_cache = {k:v.to(self.device) for k,v in self.demo_all_text_embedding_cache.items()}
+
         # sam args
         sam_ckpt_path = {
             'vit_t': './pretrained_checkpoint/mobile_sam.pt',
@@ -165,13 +188,7 @@ class FrozenSeg(nn.Module):
 
 
     def prepare_class_names_from_metadata(self, metadata, train_metadata):
-        def split_labels(x):
-            res = []
-            for x_ in x:
-                x_ = x_.replace(', ', ',')
-                x_ = x_.split(',') # there can be multiple synonyms for single class
-                res.append(x_)
-            return res
+
         # get text classifier
         try:
             class_names = split_labels(metadata.stuff_classes) # it includes both thing and stuff
@@ -188,13 +205,6 @@ class FrozenSeg(nn.Module):
         category_overlapping_mask = torch.tensor(
             category_overlapping_list, dtype=torch.long)
 
-        def fill_all_templates_ensemble(x_=''):
-            res = []
-            for x in x_:
-                for template in VILD_PROMPT:
-                    res.append(template.format(x))
-            return res, len(res) // len(VILD_PROMPT)
-
         num_templates = []
         templated_class_names = []
         for x in class_names:
@@ -228,17 +238,44 @@ class FrozenSeg(nn.Module):
             return self.train_text_classifier, self.train_num_templates
         else:
             if self.test_text_classifier is None:
+                try:
+                    nontemplated_class_names = split_labels(self.test_metadata.stuff_classes) # it includes both thing and stuff
+                except:
+                    # this could be for insseg, where only thing_classes are available
+                    nontemplated_class_names = split_labels(self.test_metadata.thing_classes)
+
+                text2classifier = {}
+                test_class_names = []
+                uncached_class_name = []
                 text_classifier = []
+                # exclude those already in cache
+                for class_names in nontemplated_class_names:
+                    for class_name in class_names:
+                        if class_name in self.demo_all_text_embedding_cache:
+                            text2classifier[class_name] = self.demo_all_text_embedding_cache[class_name].to(self.device)
+                        else:
+                            test_class_names += fill_all_templates_ensemble([class_name])[0]
+                            uncached_class_name.append(class_name)
+                print("Uncached texts:", len(uncached_class_name), uncached_class_name, test_class_names)
                 # this is needed to avoid oom, which may happen when num of class is large
                 bs = 128
-                for idx in range(0, len(self.test_class_names), bs):
-                    text_classifier.append(self.backbone.get_text_classifier(self.test_class_names[idx:idx+bs], self.device).detach())
-                text_classifier = torch.cat(text_classifier, dim=0)
-
-                # average across templates and normalization.
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
-                text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                for idx in range(0, len(test_class_names), bs):
+                    text_classifier.append(self.backbone.get_text_classifier(test_class_names[idx:idx+bs], self.device).detach())
+
+                if len(text_classifier) > 0:
+                    text_classifier = torch.cat(text_classifier, dim=0)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    assert text_classifier.shape[0] == len(uncached_class_name)
+                    for idx in range(len(uncached_class_name)):
+                        self.demo_all_text_embedding_cache[uncached_class_name[idx]] = text_classifier[idx]
+                        text2classifier[uncached_class_name[idx]] = text_classifier[idx]
+                text_classifier = []
+                for class_names in nontemplated_class_names:
+                    for text in class_names:
+                        text_classifier.append(text2classifier[text].to(self.device))
+                text_classifier = torch.stack(text_classifier, dim=0).to(self.device)
                 self.test_text_classifier = text_classifier
             return self.test_text_classifier, self.test_num_templates
 
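A note on the two helpers promoted to module scope: they are pure functions of the metadata class-name strings, so their behavior can be checked standalone. A minimal sketch (assuming only the definitions added in the diff above; the example labels are made up):

# Sketch: behavior of the module-level helpers from the diff above,
# assuming VILD_PROMPT, split_labels and fill_all_templates_ensemble
# are in scope. The label strings are illustrative only.

labels = ["person", "sofa, couch"]   # metadata-style class names

# split_labels splits comma-separated synonyms into per-class lists:
# [["person"], ["sofa", "couch"]]
per_class_synonyms = split_labels(labels)

# fill_all_templates_ensemble expands one name into len(VILD_PROMPT)
# templated sentences, e.g. "a photo of a sofa."
sentences, num_classes = fill_all_templates_ensemble(["sofa"])
assert len(sentences) == len(VILD_PROMPT) and num_classes == 1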
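The norm/reshape/mean sequence in the new branch is prompt-ensemble averaging: get_text_classifier returns one embedding per templated sentence, and the mean over the len(VILD_PROMPT) templates yields one unit-norm classifier vector per class name. A self-contained sketch of just that math, with random tensors standing in for the CLIP text outputs and assumed sizes:

import torch

T, C = 14, 512         # e.g. number of VILD templates, embed dim (assumed)
num_names = 3          # class names being embedded

# Stand-in for the backbone output: one row per templated sentence.
emb = torch.randn(num_names * T, C)

emb = emb / emb.norm(dim=-1, keepdim=True)   # normalize per sentence
emb = emb.reshape(num_names, T, C).mean(1)   # average the template ensemble
emb = emb / emb.norm(dim=-1, keepdim=True)   # renormalize the averaged vector

assert emb.shape == (num_names, C)           # one classifier vector per name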
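Nothing in this commit writes demo_all_text_embedding_cache.pth back to disk; the file is only loaded when it already exists. A hypothetical offline step to produce it might look like the sketch below (save_text_embedding_cache is not part of the repo; it assumes the model's cache dict has already been populated by running get_text_classifier over the COCO/ADE20K/LVIS vocabularies):

import torch

# Hypothetical offline step: after running inference once over the target
# vocabularies, the per-class embeddings live in
# model.demo_all_text_embedding_cache (str -> tensor of shape C).
def save_text_embedding_cache(model, path="demo_all_text_embedding_cache.pth"):
    # Move to CPU so the file loads on any device via map_location.
    cache = {k: v.detach().cpu() for k, v in model.demo_all_text_embedding_cache.items()}
    torch.save(cache, path)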