xichen98cn committed
Commit e7ae87a · 1 Parent(s): 5e46b82

Update frozenseg/frozenseg.py

Files changed (1)
  1. frozenseg/frozenseg.py +59 -22
frozenseg/frozenseg.py CHANGED
@@ -16,6 +16,7 @@ from segment_anything.build_sam import sam_model_registry
 from .modeling.transformer_decoder.frozenseg_transformer_decoder import MaskPooling, get_classification_logits
 from segment_anything import sam_model_registry
 import pickle
+import os
 VILD_PROMPT = [
     "a photo of a {}.",
     "This is a photo of a {}",
@@ -33,6 +34,20 @@ VILD_PROMPT = [
     "There is a large {} in the scene.",
 ]
 
+def split_labels(x):
+    res = []
+    for x_ in x:
+        x_ = x_.replace(', ', ',')
+        x_ = x_.split(',') # there can be multiple synonyms for single class
+        res.append(x_)
+    return res
+
+def fill_all_templates_ensemble(x_=''):
+    res = []
+    for x in x_:
+        for template in VILD_PROMPT:
+            res.append(template.format(x))
+    return res, len(res) // len(VILD_PROMPT)
 
 @META_ARCH_REGISTRY.register()
 class FrozenSeg(nn.Module):
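Note on the hunk above: split_labels and fill_all_templates_ensemble move from closures inside prepare_class_names_from_metadata to module level, so the new cache-aware path in get_text_classifier (last hunk below) can reuse them. A minimal illustration of what they return, with made-up class names:

# Illustrative only; the inputs are hypothetical, VILD_PROMPT is the template list above.
classes = ["person", "tv,television"]        # the second entry has two comma-separated synonyms
synonyms = split_labels(classes)             # [["person"], ["tv", "television"]]

prompts, n = fill_all_templates_ensemble(synonyms[1])
# prompts == ["a photo of a tv.", ..., "a photo of a television.", ...]
# n == len(prompts) // len(VILD_PROMPT) == 2, i.e. one block of templates per synonym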
@@ -132,6 +147,14 @@ class FrozenSeg(nn.Module):
 
         _, self.train_num_templates, self.train_class_names = self.prepare_class_names_from_metadata(train_metadata, train_metadata)
         self.category_overlapping_mask, self.test_num_templates, self.test_class_names = self.prepare_class_names_from_metadata(test_metadata, train_metadata)
+
+        self.demo_all_text_embedding_cache = {}
+        # This consists of COCO, ADE20K, LVIS
+        if os.path.exists("demo_all_text_embedding_cache.pth"):
+            # key: str of class name, value: tensor in shape of C
+            self.demo_all_text_embedding_cache = torch.load("demo_all_text_embedding_cache.pth", map_location=self.device)
+            self.demo_all_text_embedding_cache = {k:v.to(self.device) for k,v in self.demo_all_text_embedding_cache.items()}
+
         # sam args
         sam_ckpt_path = {
             'vit_t': './pretrained_checkpoint/mobile_sam.pt',
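Note: the hunk above only reads demo_all_text_embedding_cache.pth; the commit does not show how that file is produced. A sketch of a writer matching the layout the loader expects (a dict mapping class-name strings to C-dimensional text embeddings); the helper name is hypothetical:

import torch

def save_text_embedding_cache(class_to_embedding, path="demo_all_text_embedding_cache.pth"):
    """Hypothetical helper: persist {class name (str): C-dim tensor} for the loader above."""
    # Store on CPU so torch.load(..., map_location=...) can place it on any device later.
    cache = {name: emb.detach().cpu() for name, emb in class_to_embedding.items()}
    torch.save(cache, path)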
@@ -165,13 +188,7 @@ class FrozenSeg(nn.Module):
 
 
     def prepare_class_names_from_metadata(self, metadata, train_metadata):
-        def split_labels(x):
-            res = []
-            for x_ in x:
-                x_ = x_.replace(', ', ',')
-                x_ = x_.split(',') # there can be multiple synonyms for single class
-                res.append(x_)
-            return res
+
         # get text classifier
         try:
            class_names = split_labels(metadata.stuff_classes) # it includes both thing and stuff
@@ -188,13 +205,6 @@ class FrozenSeg(nn.Module):
         category_overlapping_mask = torch.tensor(
             category_overlapping_list, dtype=torch.long)
 
-        def fill_all_templates_ensemble(x_=''):
-            res = []
-            for x in x_:
-                for template in VILD_PROMPT:
-                    res.append(template.format(x))
-            return res, len(res) // len(VILD_PROMPT)
-
         num_templates = []
         templated_class_names = []
         for x in class_names:
@@ -228,17 +238,44 @@ class FrozenSeg(nn.Module):
             return self.train_text_classifier, self.train_num_templates
         else:
             if self.test_text_classifier is None:
+                try:
+                    nontemplated_class_names = split_labels(self.test_metadata.stuff_classes) # it includes both thing and stuff
+                except:
+                    # this could be for insseg, where only thing_classes are available
+                    nontemplated_class_names = split_labels(self.test_metadata.thing_classes)
+
+                text2classifier = {}
+                test_class_names = []
+                uncached_class_name = []
                 text_classifier = []
+                # exclude those already in cache
+                for class_names in nontemplated_class_names:
+                    for class_name in class_names:
+                        if class_name in self.demo_all_text_embedding_cache:
+                            text2classifier[class_name] = self.demo_all_text_embedding_cache[class_name].to(self.device)
+                        else:
+                            test_class_names += fill_all_templates_ensemble([class_name])[0]
+                            uncached_class_name.append(class_name)
+                print("Uncached texts:", len(uncached_class_name), uncached_class_name, test_class_names)
                 # this is needed to avoid oom, which may happen when num of class is large
                 bs = 128
-                for idx in range(0, len(self.test_class_names), bs):
-                    text_classifier.append(self.backbone.get_text_classifier(self.test_class_names[idx:idx+bs], self.device).detach())
-                text_classifier = torch.cat(text_classifier, dim=0)
-
-                # average across templates and normalization.
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
-                text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
-                text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                for idx in range(0, len(test_class_names), bs):
+                    text_classifier.append(self.backbone.get_text_classifier(test_class_names[idx:idx+bs], self.device).detach())
+
+                if len(text_classifier) > 0:
+                    text_classifier = torch.cat(text_classifier, dim=0)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    text_classifier = text_classifier.reshape(text_classifier.shape[0]//len(VILD_PROMPT), len(VILD_PROMPT), text_classifier.shape[-1]).mean(1)
+                    text_classifier /= text_classifier.norm(dim=-1, keepdim=True)
+                    assert text_classifier.shape[0] == len(uncached_class_name)
+                    for idx in range(len(uncached_class_name)):
+                        self.demo_all_text_embedding_cache[uncached_class_name[idx]] = text_classifier[idx]
+                        text2classifier[uncached_class_name[idx]] = text_classifier[idx]
+                text_classifier = []
+                for class_names in nontemplated_class_names:
+                    for text in class_names:
+                        text_classifier.append(text2classifier[text].to(self.device))
+                text_classifier = torch.stack(text_classifier, dim=0).to(self.device)
                 self.test_text_classifier = text_classifier
             return self.test_text_classifier, self.test_num_templates
 
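Note on the last hunk: per-class embeddings are now looked up in demo_all_text_embedding_cache first; only uncached names are expanded into VILD_PROMPT prompts, encoded in batches of 128 to avoid OOM, averaged over templates, renormalized, and written back to the in-memory cache before everything is stacked in vocabulary order. A minimal sketch of the ensemble-averaging step on dummy data (T, N, C and the random embeddings are placeholders, not values from the repo):

import torch

T = 14                                   # number of VILD_PROMPT templates in this file
N = 3                                    # uncached class names in this toy example
C = 512                                  # text embedding dim; depends on the CLIP backbone
per_prompt = torch.randn(N * T, C)       # stand-in for backbone.get_text_classifier output

per_prompt = per_prompt / per_prompt.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
per_class = per_prompt.reshape(N, T, C).mean(1)                  # average over the T templates
per_class = per_class / per_class.norm(dim=-1, keepdim=True)     # renormalize the averaged rows
assert per_class.shape == (N, C)         # one classifier row per uncached class name

One observable difference from the old path: the refreshed cache lives only in memory; this change does not write it back to disk.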