Spaces: Running on L4
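Fixed a small bug in the mask handling of CLIPImageContextEncoder: added the `torch.nn.functional as F` import, removed the unused `masked_images` computation, and deleted the commented-out `_encode_wmask`.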
Browse files - lib/model_zoo/clip.py (+1 −24)
lib/model_zoo/clip.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
|
|
3 |
import numpy as np
|
4 |
from functools import partial
|
5 |
from lib.model_zoo.common.get_model import register
|
|
|
6 |
|
7 |
symbol = 'clip'
|
8 |
|
@@ -104,7 +105,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
|
|
104 |
assert isinstance(masks, torch.Tensor)
|
105 |
assert (len(masks.shape)==4) and (masks.shape[1]==1)
|
106 |
masks = torch.clamp(masks, 0, 1)
|
107 |
-
masked_images = images*masks
|
108 |
masks = masks.float()
|
109 |
masks = F.interpolate(masks, [224, 224], mode='bilinear')
|
110 |
if masks.sum() == masks.numel():
|
@@ -142,29 +142,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
|
|
142 |
z = z * vtoken_mask.to(dtype)
|
143 |
return z
|
144 |
|
145 |
-
# def _encode_wmask(self, images, masks):
|
146 |
-
# assert isinstance(masks, torch.Tensor)
|
147 |
-
# assert (len(masks.shape)==4) and (masks.shape[1]==1)
|
148 |
-
# masks = torch.clamp(masks, 0, 1)
|
149 |
-
# masks = masks.float()
|
150 |
-
# masks = F.interpolate(masks, [224, 224], mode='bilinear')
|
151 |
-
# if masks.sum() == masks.numel():
|
152 |
-
# return self._encode(images)
|
153 |
-
|
154 |
-
# device = images.device
|
155 |
-
# dtype = images.dtype
|
156 |
-
|
157 |
-
# vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
|
158 |
-
# vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
|
159 |
-
# mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
|
160 |
-
# vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
|
161 |
-
# vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
|
162 |
-
|
163 |
-
# z = self._encode(images)
|
164 |
-
# z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
|
165 |
-
# z[:, 0, :] = 0
|
166 |
-
# return z
|
167 |
-
|
168 |
def encode(self, images, masks=None):
|
169 |
if masks is None:
|
170 |
return self._encode(images)
|
|
|
3 |
import numpy as np
|
4 |
from functools import partial
|
5 |
from lib.model_zoo.common.get_model import register
|
6 |
+
import torch.nn.functional as F
|
7 |
|
8 |
symbol = 'clip'
|
9 |
|
|
|
105 |
assert isinstance(masks, torch.Tensor)
|
106 |
assert (len(masks.shape)==4) and (masks.shape[1]==1)
|
107 |
masks = torch.clamp(masks, 0, 1)
|
|
|
108 |
masks = masks.float()
|
109 |
masks = F.interpolate(masks, [224, 224], mode='bilinear')
|
110 |
if masks.sum() == masks.numel():
|
|
|
142 |
z = z * vtoken_mask.to(dtype)
|
143 |
return z
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def encode(self, images, masks=None):
|
146 |
if masks is None:
|
147 |
return self._encode(images)
|