JeffLiang committed
Commit 8c62972
Parent(s): ba09e2c
change sam_vit_h to sam_vit_l to save memory
Files changed:
- app.py +1 -1
- open_vocab_seg/utils/predictor.py +9 -4
- sam_vit_l_0b3195.pth +3 -0
app.py
CHANGED
@@ -45,7 +45,7 @@ def inference(class_names, proposal_gen, granularity, input_img):
     if proposal_gen == 'MaskFormer':
         demo = VisualizationDemo(cfg)
     elif proposal_gen == 'Segment_Anything':
-        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_h_4b8939.pth', './ovseg_clip_l_9a1909.pth')
+        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
     class_names = class_names.split(',')
     img = read_image(input_img, format="BGR")
     _, visualized_output = demo.run_on_image(img, class_names)
open_vocab_seg/utils/predictor.py
CHANGED
@@ -150,7 +150,7 @@ class SAMVisualizationDemo(object):
 
         self.parallel = parallel
         self.granularity = granularity
-        sam = sam_model_registry["vit_h"](checkpoint=sam_path).cuda()
+        sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda()
         self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16)
         self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
         self.clip_model.cuda()
@@ -189,12 +189,17 @@ class SAMVisualizationDemo(object):
         txts = [f'a photo of {cls_name}' for cls_name in class_names]
         text = open_clip.tokenize(txts)
 
+        img_batches = torch.split(imgs, 32, dim=0)
+
         with torch.no_grad(), torch.cuda.amp.autocast():
-            image_features = self.clip_model.encode_image(imgs.cuda().half())
             text_features = self.clip_model.encode_text(text.cuda())
-            image_features /= image_features.norm(dim=-1, keepdim=True)
             text_features /= text_features.norm(dim=-1, keepdim=True)
-
+            image_features = []
+            for img_batch in img_batches:
+                image_feat = self.clip_model.encode_image(img_batch.cuda().half())
+                image_feat /= image_feat.norm(dim=-1, keepdim=True)
+                image_features.append(image_feat.detach())
+            image_features = torch.cat(image_features, dim=0)
         class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
         select_cls = torch.zeros_like(class_preds)
 
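The second hunk bounds peak GPU memory during CLIP scoring: instead of encoding every mask crop in a single encode_image call, the crops are split into batches of 32, each batch is encoded and L2-normalized, and the per-batch features are concatenated back into one image_features tensor, so the result is unchanged while activation memory stays roughly constant. The same pattern in isolation, as a rough sketch (clip_model and imgs stand in for any CLIP-style encoder and stacked crop tensor; only open_clip's encode_image is assumed):

import torch

def encode_in_batches(clip_model, imgs, batch_size=32):
    # Encode crops batch-by-batch so peak activation memory stays bounded.
    feats = []
    with torch.no_grad(), torch.cuda.amp.autocast():
        for batch in torch.split(imgs, batch_size, dim=0):
            f = clip_model.encode_image(batch.cuda().half())
            f = f / f.norm(dim=-1, keepdim=True)  # L2-normalize each crop feature
            feats.append(f.detach())
    return torch.cat(feats, dim=0)

Lowering batch_size trades a little speed for a lower memory ceiling; 32 is the value this commit settles on.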
sam_vit_l_0b3195.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
+size 1249524607
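The .pth entry itself is only a Git LFS pointer; the ~1.25 GB weights live in LFS storage and are pulled when the Space is cloned with LFS enabled. If app.py is run outside the Space, the checkpoint has to exist at the repository root first. A hedged sketch of fetching it, assuming the file corresponds to the official Segment Anything ViT-L release (the URL is an assumption, not taken from this commit):

# Sketch: fetch the SAM ViT-L checkpoint if it is missing.
# Assumption: the file matches the official Segment Anything release at this URL.
import os
import urllib.request

CKPT = "sam_vit_l_0b3195.pth"
URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"

if not os.path.exists(CKPT):
    urllib.request.urlretrieve(URL, CKPT)
    print(f"downloaded {CKPT}")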