initial commit
- README.md +34 -0
- VideoCLIP-XL.bin +3 -0
- demo.py +77 -0
- modeling.py +18 -0
- requirements.txt +6 -0
- utils/__init__.py +0 -0
- utils/text_encoder/__init__.py +1 -0
- utils/text_encoder/bpe_simple_vocab_16e6.txt.gz +3 -0
- utils/text_encoder/model_text_encoder.py +395 -0
- utils/text_encoder/simple_tokenizer.py +132 -0
- utils/text_encoder/text_encoder.py +75 -0
- utils/vision_encoder/__init__.py +11 -0
- utils/vision_encoder/clip_vision.py +327 -0
- utils/vision_encoder/model_vision_encoder.py +84 -0
README.md
ADDED
@@ -0,0 +1,34 @@
# What's New

[2024/10] A new [VideoCLIP-XL-v2](https://huggingface.co/alibaba-pai/VideoCLIP-XL-v2) model has been released.

[2024/10] Initial commit for the [VideoCLIP-XL](https://huggingface.co/alibaba-pai/VideoCLIP-XL) model, the [VILD](https://huggingface.co/alibaba-pai/VILD) dataset, and the [LVDR](https://huggingface.co/alibaba-pai/LVDR) benchmark.

# VideoCLIP-XL (eXtra Length)

This model is proposed in the [VideoCLIP-XL paper](https://arxiv.org/abs/2410.00741).
It aims to advance long-description understanding for video CLIP models.

# Install
~~~
# 1. Create your environment
# 2. Install torch
# 3. Then:
pip install -r requirements.txt
~~~

# Usage
Please refer to `demo.py`.
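
A minimal usage sketch, condensed from `demo.py` in this commit; the video path and caption are placeholders, and `load_video` is a compact stand-in for the `video_preprocessing` helper defined in `demo.py`:
~~~python
import cv2
import numpy as np
import torch

from modeling import VideoCLIP_XL
from utils.text_encoder import text_encoder

MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
STD = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)

def load_video(path, fnum=8):
    # Uniformly sample `fnum` frames, convert BGR -> RGB, resize to 224x224, normalize.
    cap = cv2.VideoCapture(path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
    frames = frames[::max(len(frames) // fnum, 1)][:fnum]
    clip = [(cv2.resize(f[:, :, ::-1], (224, 224)) / 255.0 - MEAN) / STD for f in frames]
    clip = np.stack(clip)[None]                              # [1, fnum, H, W, C]
    return torch.from_numpy(clip.transpose(0, 1, 4, 2, 3))   # [1, fnum, C, H, W]

model = VideoCLIP_XL()
model.load_state_dict(torch.load("./VideoCLIP-XL.bin", map_location="cpu"))
model = model.cuda().eval()

videos = ["/path/to/video.mp4"]               # placeholder path
texts = ["a long description of the video"]   # placeholder caption

with torch.no_grad():
    vid = torch.cat([load_video(v) for v in videos], 0).float().cuda()
    vid_feat = model.vision_model.get_vid_features(vid).float()
    vid_feat = vid_feat / vid_feat.norm(dim=-1, keepdim=True)

    txt = text_encoder.tokenize(texts, truncate=True).cuda()
    txt_feat = model.text_model.encode_text(txt).float()
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

print((txt_feat @ vid_feat.T) * 100.0)  # text-video similarity logits
~~~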

# Citation
~~~
@misc{wang2024videoclipxladvancinglongdescription,
      title={VideoCLIP-XL: Advancing Long Description Understanding for Video CLIP Models},
      author={Jiapeng Wang and Chengyu Wang and Kunzhe Huang and Jun Huang and Lianwen Jin},
      year={2024},
      eprint={2410.00741},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.00741},
}
~~~
VideoCLIP-XL.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:413b6adfd78dd4c0e602df1b5aaa21f3ff35d8f00993069fd51f0c924e0a0113
size 1711973262
demo.py
ADDED
@@ -0,0 +1,77 @@
import os
from typing import List
import cv2
import numpy as np
import torch
from PIL import Image
import torch.nn.functional as F

from modeling import VideoCLIP_XL
from utils.text_encoder import text_encoder


def _frame_from_video(video):
    # Yield decoded frames until the capture is exhausted.
    while video.isOpened():
        success, frame = video.read()
        if success:
            yield frame
        else:
            break


# ImageNet mean/std used by the CLIP-style vision encoder.
v_mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
v_std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)


def normalize(data):
    return (data / 255.0 - v_mean) / v_std


def video_preprocessing(video_path, fnum=8):
    # Uniformly sample `fnum` frames, convert BGR -> RGB, resize to 224x224 and normalize.
    video = cv2.VideoCapture(video_path)
    frames = [x for x in _frame_from_video(video)]
    step = max(len(frames) // fnum, 1)  # guard against videos with fewer than `fnum` frames
    frames = frames[::step][:fnum]

    vid_tube = []
    for fr in frames:
        fr = fr[:, :, ::-1]  # BGR -> RGB
        fr = cv2.resize(fr, (224, 224))
        fr = np.expand_dims(normalize(fr), axis=(0, 1))
        vid_tube.append(fr)
    vid_tube = np.concatenate(vid_tube, axis=1)          # [1, fnum, H, W, C]
    vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))   # [1, fnum, C, H, W]
    vid_tube = torch.from_numpy(vid_tube)

    return vid_tube


videoclip_xl = VideoCLIP_XL()
state_dict = torch.load("./VideoCLIP-XL.bin", map_location="cpu")
videoclip_xl.load_state_dict(state_dict)
videoclip_xl.cuda().eval()


videos = [
    "/path/to/video-1.mp4",
    "/path/to/video-2.mp4",
]

texts = [
    "text-1",
    "text-2",
    "text-3",
]

with torch.no_grad():
    video_inputs = torch.cat([video_preprocessing(video) for video in videos], 0).float().cuda()
    video_features = videoclip_xl.vision_model.get_vid_features(video_inputs).float()
    video_features = video_features / video_features.norm(dim=-1, keepdim=True)

    text_inputs = text_encoder.tokenize(texts, truncate=True).cuda()
    text_features = videoclip_xl.text_model.encode_text(text_inputs).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    Tmp = 100.  # scaling factor (inverse temperature)

    sim_matrix = (text_features @ video_features.T) * Tmp

    print(sim_matrix)
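A side note that is not part of the original `demo.py`: `sim_matrix` has shape `[len(texts), len(videos)]`, and a softmax over the video axis turns each row into text-to-video retrieval probabilities. A small sketch with hypothetical logits:
~~~python
import torch

# Hypothetical logits for 3 texts x 2 videos, in the layout demo.py prints.
sim_matrix = torch.tensor([[28.1, 17.3],
                           [16.9, 30.2],
                           [21.0, 22.4]])

probs = sim_matrix.softmax(dim=-1)   # per-text distribution over videos
print(probs)
print(probs.argmax(dim=-1))          # best-matching video per text -> tensor([0, 1, 1])
~~~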
modeling.py
ADDED
@@ -0,0 +1,18 @@
import os
from typing import List

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from utils.text_encoder import text_encoder
from utils.vision_encoder import get_vision_encoder


class VideoCLIP_XL(nn.Module):
    def __init__(self):
        super(VideoCLIP_XL, self).__init__()
        # Two towers: a long-text CLIP text encoder (248-token context) and a ViT-L/14 video vision encoder.
        self.text_model = text_encoder.load().float()
        self.vision_model = get_vision_encoder().float()
requirements.txt
ADDED
@@ -0,0 +1,6 @@
opencv-python
ftfy
regex
timm
decord
einops
utils/__init__.py
ADDED
File without changes
utils/text_encoder/__init__.py
ADDED
@@ -0,0 +1 @@
from .text_encoder import *
utils/text_encoder/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
utils/text_encoder/model_text_encoder.py
ADDED
@@ -0,0 +1,395 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 load_from_clip: bool
                 ):
        super().__init__()

        self.context_length = 248  # fixed long-text context length (overrides the argument)

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)

        if not load_from_clip:
            self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
            self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))

        else:
            self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))

        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()
        # Position masks for the long-text embedding: the first 20 positions use
        # `positional_embedding`, the remaining positions use `positional_embedding_res`.
        self.mask1 = torch.zeros([248, 1])
        self.mask1[:20, :] = 1
        self.mask2 = torch.zeros([248, 1])
        self.mask2[20:, :] = 1

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.token_embedding.weight.dtype

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_text_full(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        return x


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(load_from_clip: bool):

    vision_width = 1024
    vision_layers = 24
    vision_patch_size = 14
    grid_size = 16
    image_resolution = 224

    embed_dim = 768
    context_length = 248
    vocab_size = 49408
    transformer_width = 768
    transformer_heads = 12
    transformer_layers = 12

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip
    )

    convert_weights(model)
    return model.eval()
utils/text_encoder/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
utils/text_encoder/text_encoder.py
ADDED
@@ -0,0 +1,75 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
from torch import nn
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from .model_text_encoder import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


_tokenizer = _Tokenizer()


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def load():
    model = build_model(load_from_clip=False)

    return model


def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; this long-text model defaults to 248 (= 77*4-60) rather than the 77 used by the original CLIP

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
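For illustration only (not part of the commit): tokenizing with the default long context produces a padded integer tensor of shape `[batch, 248]`. The captions below are placeholders.
~~~python
from utils.text_encoder import text_encoder

tokens = text_encoder.tokenize(
    ["a short caption", "a much longer description of the whole video"],
    truncate=True,
)
print(tokens.shape)  # torch.Size([2, 248]); 248 = 77*4-60
print(tokens.dtype)  # torch.int32 on torch >= 1.8, torch.int64 otherwise
~~~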
utils/vision_encoder/__init__.py
ADDED
@@ -0,0 +1,11 @@
import torch
import numpy as np
import cv2
import os

from .model_vision_encoder import VisionEncoder

def get_vision_encoder():
    vision_encoder = VisionEncoder()

    return vision_encoder
utils/vision_encoder/clip_vision.py
ADDED
@@ -0,0 +1,327 @@
#!/usr/bin/env python
import os
import logging
from collections import OrderedDict

import torch
from torch import nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.registry import register_model

import torch.utils.checkpoint as checkpoint

logger = logging.getLogger(__name__)

def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """
    Add/Remove extra temporal_embeddings as needed.
    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    temp_embed_old: (1, num_frames_old, 1, d)
    temp_embed_new: (1, num_frames_new, 1, d)
    add_zero: bool, if True, add zero, else, interpolate trained embeddings.
    """
    # TODO zero pad
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
    if num_frms_new > num_frms_old:
        if add_zero:
            temp_embed_new[
                :, :num_frms_old
            ] = temp_embed_old  # untrained embeddings are zeros.
        else:
            temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
    elif num_frms_new < num_frms_old:
        temp_embed_new = temp_embed_old[:, :num_frms_new]
    else:  # =
        temp_embed_new = temp_embed_old
    return temp_embed_new


class QuickGELU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
        super().__init__()

        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # logger.info(f'Droppath: {drop_path}')
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("drop1", nn.Dropout(dropout)),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
            ("drop2", nn.Dropout(dropout)),
        ]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x):
        x = x + self.drop_path1(self.attention(self.ln_1(x)))
        x = x + self.drop_path2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
        super().__init__()
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList()
        for idx in range(layers):
            self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
        self.checkpoint_num = checkpoint_num

    def forward(self, x):
        for idx, blk in enumerate(self.resblocks):
            if idx < self.checkpoint_num:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self, input_resolution, patch_size, width, layers, heads, output_dim=None,
        kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
        temp_embed=True,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.conv1 = nn.Conv3d(
            3, width,
            (kernel_size, patch_size, patch_size),
            (kernel_size, patch_size, patch_size),
            (0, 0, 0), bias=False
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = nn.LayerNorm(width)
        if temp_embed:
            self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))

        self.transformer = Transformer(
            width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
            dropout=dropout)

        self.ln_post = nn.LayerNorm(width)
        if output_dim is not None:
            self.proj = nn.Parameter(torch.empty(width, output_dim))
        else:
            self.proj = None

        self.dropout = nn.Dropout(dropout)

    def get_num_layers(self):
        return len(self.transformer.resblocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}

    def mask_tokens(self, inputs, masking_prob=0.0):
        B, L, _ = inputs.shape

        # This is different from text as we are masking a fix number of tokens
        Lm = int(masking_prob * L)
        masked_indices = torch.zeros(B, L)
        indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
        batch_indices = (
            torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices)
        )
        masked_indices[batch_indices, indices] = 1

        masked_indices = masked_indices.bool()

        return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])

    def forward(self, x, masking_prob=0.0):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)

        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # temporal pos
        cls_tokens = x[:B, :1, :]
        x = x[:, 1:]
        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
        if hasattr(self, 'temporal_positional_embedding'):
            if x.size(1) == 1:
                # This is a workaround for unused parameter issue
                x = x + self.temporal_positional_embedding.mean(1)
            else:
                x = x + self.temporal_positional_embedding
        x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)

        if masking_prob > 0.0:
            x = self.mask_tokens(x, masking_prob)

        x = torch.cat((cls_tokens, x), dim=1)

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # BND -> NBD
        x = self.transformer(x)

        x = self.ln_post(x)

        if self.proj is not None:
            x = self.dropout(x[0]) @ self.proj
        else:
            x = x.permute(1, 0, 2)  # NBD -> BND

        return x


def inflate_weight(weight_2d, time_dim, center=True):
    logger.info(f'Init center: {center}')
    if center:
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim
    return weight_3d


def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
    state_dict_3d = model.state_dict()
    for k in state_dict.keys():
        if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
            if len(state_dict_3d[k].shape) <= 2:
                logger.info(f'Ignore: {k}')
                continue
            logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
            time_dim = state_dict_3d[k].shape[2]
            state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)

    pos_embed_checkpoint = state_dict['positional_embedding']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = (input_resolution // patch_size) ** 2
    orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size != new_size:
        logger.info(f'Pos_emb from {orig_size} to {new_size}')
        extra_tokens = pos_embed_checkpoint[:1]
        pos_tokens = pos_embed_checkpoint[1:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        state_dict['positional_embedding'] = new_pos_embed

    message = model.load_state_dict(state_dict, strict=False)
    logger.info(f"Load pretrained weights: {message}")


@register_model
def clip_joint_b16(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=16,
        width=768, layers=12, heads=12, output_dim=512,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path, checkpoint_num=checkpoint_num,
        dropout=dropout,
    )
    if pretrained:
        if isinstance(pretrained, str):
            model_name = pretrained
        else:
            model_name = "ViT-B/16"

        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
    return model.eval()


@register_model
def clip_joint_l14(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=14,
        width=1024, layers=24, heads=16, output_dim=768,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path, checkpoint_num=checkpoint_num,
        dropout=dropout,
    )

    if pretrained:
        if isinstance(pretrained, str):
            model_name = pretrained
        else:
            model_name = "ViT-L/14"
        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
    return model.eval()


@register_model
def clip_joint_l14_336(
    pretrained=True, input_resolution=336, kernel_size=1,
    center=True, num_frames=8, drop_path=0.
):
    raise NotImplementedError
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=14,
        width=1024, layers=24, heads=16, output_dim=768,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path,
    )
    if pretrained:
        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
    return model.eval()


def interpolate_pos_embed_vit(state_dict, new_model):
    key = "vision_encoder.temporal_positional_embedding"
    if key in state_dict:
        vision_temp_embed_new = new_model.state_dict()[key]
        vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2)  # [1, n, d] -> [1, n, 1, d]
        vision_temp_embed_old = state_dict[key]
        vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)

        state_dict[key] = load_temp_embed_with_mismatch(
            vision_temp_embed_old, vision_temp_embed_new, add_zero=False
        ).squeeze(2)

    key = "text_encoder.positional_embedding"
    if key in state_dict:
        text_temp_embed_new = new_model.state_dict()[key]
        text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2)  # [n, d] -> [1, n, 1, d]
        text_temp_embed_old = state_dict[key]
        text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)

        state_dict[key] = load_temp_embed_with_mismatch(
            text_temp_embed_old, text_temp_embed_new, add_zero=False
        ).squeeze(2).squeeze(0)
    return state_dict
utils/vision_encoder/model_vision_encoder.py
ADDED
@@ -0,0 +1,84 @@
import os
import logging
import torch
from torch import nn
import math

from .clip_vision import clip_joint_l14, clip_joint_b16

logger = logging.getLogger(__name__)


class VisionEncoder(nn.Module):

    def __init__(self):
        super(VisionEncoder, self).__init__()

        self.vision_encoder_name = 'vit_l14'
        self.vision_encoder_pretrained = False
        self.inputs_image_res = 224
        self.vision_encoder_kernel_size = 1
        self.vision_encoder_center = True
        self.video_input_num_frames = 8
        self.vision_encoder_drop_path_rate = 0.1
        self.vision_encoder_checkpoint_num = 24

        self.vision_width = 1024
        self.embed_dim = 768
        self.masking_prob = 0.9

        self.vision_encoder = self.build_vision_encoder()

        self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
        self.temp_min = 1 / 100.0

    def no_weight_decay(self):
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )

        return ret

    def encode_vision(self, image, test=False):
        if image.ndim == 5:
            image = image.permute(0, 2, 1, 3, 4).contiguous()
        else:
            image = image.unsqueeze(2)

        # Token masking is only applied during training; test-time inference uses all tokens.
        if not test and self.masking_prob > 0.0:
            return self.vision_encoder(
                image, masking_prob=self.masking_prob
            )

        return self.vision_encoder(image)

    @torch.no_grad()
    def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
        """Seems only used during pre-training"""
        self.temp.clamp_(min=self.temp_min)

    def build_vision_encoder(self):
        """Build the ViT-L/14 video vision encoder and return it as a `nn.Module`."""
        vision_encoder = clip_joint_l14(
            pretrained=self.vision_encoder_pretrained,
            input_resolution=self.inputs_image_res,
            kernel_size=self.vision_encoder_kernel_size,
            center=self.vision_encoder_center,
            num_frames=self.video_input_num_frames,
            drop_path=self.vision_encoder_drop_path_rate,
            checkpoint_num=self.vision_encoder_checkpoint_num,
        )

        return vision_encoder

    def get_vid_features(self, input_frames):
        clip_feat = self.encode_vision(input_frames, test=True).float()

        return clip_feat