Upload 9 files

- models/__init__.py +10 -0
- models/arguments_live.py +54 -0
- models/configuration_live.py +21 -0
- models/live_llama/__init__.py +2 -0
- models/live_llama/configuration_live_llama.py +7 -0
- models/live_llama/modeling_live_llama.py +154 -0
- models/modeling_live.py +222 -0
- models/tokenization_live.py +153 -0
- models/vision_live.py +61 -0
models/__init__.py
ADDED
@@ -0,0 +1,10 @@
from transformers import HfArgumentParser

from .arguments_live import LiveTrainingArguments, get_args_class
from .live_llama import build_live_llama as build_model_and_tokenizer
from .modeling_live import fast_greedy_generate

def parse_args() -> LiveTrainingArguments:
    args, = HfArgumentParser(LiveTrainingArguments).parse_args_into_dataclasses()
    args, = HfArgumentParser(get_args_class(args.live_version)).parse_args_into_dataclasses()
    return args
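A minimal usage sketch of these exports (illustrative only, not part of the uploaded files; it assumes the package is importable as `models` and mirrors the `__main__` block in modeling_live_llama.py):

from models import parse_args, build_model_and_tokenizer

if __name__ == '__main__':
    args = parse_args()  # resolves live_version first, then re-parses with the matching dataclass
    model, tokenizer = build_model_and_tokenizer(is_training=True, **args.to_dict())
    print(type(model).__name__, type(tokenizer).__name__)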
models/arguments_live.py
ADDED
@@ -0,0 +1,54 @@
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class LiveTrainingArguments(TrainingArguments):
    live_version: str = 'live1+'
    system_prompt: str = (
        "A multimodal AI assistant is helping users with some activities."
        " Below is their conversation, interleaved with the list of video frames received by the assistant."
    )
    train_datasets: list[str] = None
    eval_datasets: list[str] = None
    stream_loss_weight: float = 1.0
    llm_pretrained: str = 'meta-llama/Meta-Llama-3-8B-Instruct'
    vision_pretrained: str = 'google/siglip-large-patch16-384'
    lora_modules: str = "model.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)|lm_head$"
    lora_r: int = 128
    lora_alpha: int = 256
    finetune_modules: list[str] = field(default_factory=lambda: ['connector'])
    frame_fps: int = 2 # for training. inference can be 10
    frame_token_cls: bool = None
    frame_token_pooled: list[int] = None
    frame_resolution: int = 384
    frame_token_interval: str = None
    frame_token_interval_threshold: float = 0.0
    augmentation: bool = False
    attn_implementation: str = 'flash_attention_2'
    output_dir: str = 'outputs/debug'

@dataclass
class LiveOneTrainingArguments(LiveTrainingArguments):
    live_version: str = 'live1'
    frame_token_cls: bool = True
    frame_num_tokens: int = 1
    frame_token_interval: str = ''
    embed_mark: str = '2fps_384_1'
    max_num_frames: int = 7200 # 1h, 2fps, 7200 frames

@dataclass
class LiveOnePlusTrainingArguments(LiveTrainingArguments):
    live_version: str = 'live1+'
    frame_token_cls: bool = True
    frame_token_pooled: list[int] = field(default_factory=lambda: [3,3])
    frame_num_tokens: int = 10 # 1+3x3
    embed_mark: str = '2fps_384_1+3x3'
    frame_token_interval: str = ','
    max_num_frames: int = 1200 # 10min, 2fps, 1200 frames

def get_args_class(live_version: str):
    if live_version == 'live1':
        return LiveOneTrainingArguments
    elif live_version == 'live1+':
        return LiveOnePlusTrainingArguments
    raise NotImplementedError
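A quick, hypothetical check of how get_args_class resolves the version-specific defaults above (not part of the uploaded files):

args = get_args_class('live1+')()
assert isinstance(args, LiveOnePlusTrainingArguments)
assert args.frame_num_tokens == 10 and args.frame_token_interval == ','
assert get_args_class('live1')().max_num_frames == 7200  # 1h of frames at 2 fps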
models/configuration_live.py
ADDED
@@ -0,0 +1,21 @@
from transformers import PretrainedConfig

class LiveConfigMixin(PretrainedConfig):
    def __init__(self, *, vision_pretrained: str = None,
        frame_resolution: int = None, frame_token_cls: bool = None, frame_token_pooled: list[int] = None, frame_num_tokens: int = None,
        v_placeholder: str = '<v>', frame_token_interval: str = None, v_placeholder_id: int = None, frame_token_interval_id: int = None,
        stream_loss_weight: float = 1.0, vision_hidden_size=1024, **kwargs
    ):
        super().__init__(**kwargs)
        self.vision_pretrained = vision_pretrained
        self.frame_resolution = frame_resolution
        self.frame_token_cls = frame_token_cls
        self.frame_token_pooled = frame_token_pooled
        self.frame_num_tokens = frame_num_tokens
        self.vision_hidden_size = vision_hidden_size
        self.stream_loss_weight = stream_loss_weight
        self.v_placeholder = v_placeholder
        self.frame_token_interval = frame_token_interval
        self.v_placeholder_id = v_placeholder_id
        self.frame_token_interval_id = frame_token_interval_id
models/live_llama/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .configuration_live_llama import LiveLlamaConfig
from .modeling_live_llama import LiveLlamaForCausalLM, build_live_llama
models/live_llama/configuration_live_llama.py
ADDED
@@ -0,0 +1,7 @@
from transformers import LlamaConfig

from ..configuration_live import LiveConfigMixin

class LiveLlamaConfig(LlamaConfig, LiveConfigMixin):
    pass
models/live_llama/modeling_live_llama.py
ADDED
@@ -0,0 +1,154 @@
import torch
from torch import nn
from transformers import LlamaForCausalLM, Cache
from transformers.activations import GELUActivation
from transformers.utils import logging

from .configuration_live_llama import LiveLlamaConfig
from ..modeling_live import build_live, LiveMixin

logger = logging.get_logger(__name__)

class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
    config_class = LiveLlamaConfig
    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']

    def __init__(self, config: LiveLlamaConfig):
        super().__init__(config)
        self.connector = torch.nn.Sequential(
            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
            GELUActivation(config.hidden_size),
            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        frames: torch.FloatTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        labels: torch.LongTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.joint_embed(input_ids, frames)
        outputs = super().forward(
            attention_mask = attention_mask,
            position_ids = position_ids,
            past_key_values = past_key_values,
            inputs_embeds = inputs_embeds,
            # labels are intentionally not forwarded; the stream-weighted loss is computed below
            use_cache = use_cache,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            return_dict = return_dict,
            cache_position=cache_position,
        )

        loss = None
        if labels is not None:
            logits = outputs[0]
            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
            weight = v_mask * self.config.stream_loss_weight + ~v_mask
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
            loss = loss.sum() / (labels >= 0).sum()

        if not return_dict:
            return (loss,) + outputs[1:] if loss is not None else outputs

        outputs.loss = loss
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, past_length:] # NOTE

        # NOTE
        if inputs_embeds is not None and past_length < inputs_embeds.size(1):
            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,  # covers only the new inputs, starting from past_length
                "cache_position": cache_position,  # the positions not yet in the cache
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,  # length of cache + new inputs
            }
        )
        return model_inputs

def build_live_llama(**kwargs):
    return build_live(config_class=LiveLlamaConfig, model_class=LiveLlamaForCausalLM, **kwargs)

if __name__ == '__main__':
    from ..arguments_live import LiveOnePlusTrainingArguments
    print(LiveOnePlusTrainingArguments().to_dict())
    model, tokenizer = build_live_llama(is_training=True, **LiveOnePlusTrainingArguments().to_dict())
    print(model.config, tokenizer)
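To make the per-token weighting in forward above concrete, here is a small standalone sketch with toy tensors (the vocabulary size and placeholder id are hypothetical, not part of the uploaded files): stream tokens (v_placeholder_id) are scaled by stream_loss_weight, all other tokens keep weight 1, and ignored labels (-100) contribute zero loss.

import torch
from torch import nn

v_placeholder_id, stream_loss_weight = 128256, 1.0  # hypothetical values
logits = torch.randn(1, 4, 128257)                  # (batch, seq, vocab)
input_ids = torch.tensor([[1, v_placeholder_id, v_placeholder_id, 2]])
labels = torch.tensor([[5, -100, 6, 7]])            # -100 positions are ignored by cross_entropy

v_mask = input_ids.flatten(0, 1) == v_placeholder_id
weight = v_mask * stream_loss_weight + ~v_mask      # stream tokens weighted, text tokens keep 1
loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
loss = loss.sum() / (labels >= 0).sum()
print(loss)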
models/modeling_live.py
ADDED
@@ -0,0 +1,222 @@
import torch, os
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoModelForCausalLM, Cache
from transformers.utils import logging

from .tokenization_live import build_live_tokenizer_and_update_config
from .vision_live import build_live_vision

logger = logging.get_logger(__name__)

class LiveMixin(AutoModelForCausalLM):
    def set_vision_inside(self):
        logger.warning_once("!!! Putting the vision encoder inside the model is only recommended for in-the-wild inference. "
            "Please do not call this for efficient training & evaluation; do visual feature pre-extraction instead.")
        self.vision_encoder, self.vision_encode = build_live_vision(self.config)

    def unset_vision_inside(self):
        del self.vision_encoder
        del self.vision_encode

    def visual_embed(self, frames: torch.Tensor):
        if hasattr(self, 'vision_encode'):
            with torch.cuda.amp.autocast():
                frames = self.vision_encode(self.vision_encoder, frames)
            frames = frames.to(self.dtype)
        frames = self.connector(frames)
        return frames.view(-1, frames.shape[-1])

    def joint_embed(
        self,
        input_ids: torch.Tensor = None,
        frames: torch.Tensor = None,
    ):
        if frames is None:
            return self.get_input_embeddings()(input_ids)
        if input_ids is None:
            return self.visual_embed(frames)
        inputs_embeds = self.get_input_embeddings()(input_ids.clamp(max=self.vocab_size-1))
        v_mask = input_ids == self.config.v_placeholder_id
        if v_mask.any():
            inputs_embeds[v_mask] = self.visual_embed(frames)
        return inputs_embeds

    @torch.no_grad()
    def stream_evaluate(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        frames: torch.ByteTensor,
        ignore_token_id: int = -100,
        frame_token_interval_threshold: float = 0.0,
        **kwargs
    ):
        # 0. evaluation only supports batch_size = 1
        assert input_ids.size(0) == labels.size(0) == 1
        input_id, label = input_ids[0], labels[0]
        device = input_id.device
        zero = torch.tensor(0, dtype=torch.int, device=device)
        one = torch.tensor(1, dtype=torch.int, device=device)

        # 1. prepare multi-turn start and stop
        turn_stops = ((input_id == self.config.eos_token_id).nonzero() + 1)[:,0].tolist()
        turn_starts = [0] + turn_stops[:-1]
        num_turns = len(turn_starts)

        # 2. forward the full input_ids and labels, get tokenwise logits and losses
        outputs = self.forward(input_ids=input_ids, frames=frames, return_dict=True, use_cache=True)
        logit, past_key_values = outputs.logits[0], outputs.past_key_values

        # 3. compute metrics for each turn
        v_placeholder_id = self.config.v_placeholder_id
        use_interval = self.config.frame_token_interval_id is not None
        frame_token_interval_id = self.config.frame_token_interval_id if use_interval else self.config.eos_token_id
        frame_num_tokens = self.config.frame_token_cls
        if self.config.frame_token_pooled:
            frame_num_tokens += self.config.frame_token_pooled[0] * self.config.frame_token_pooled[1]
        past_num_frames = 0
        lm_ppls, frame_diffs, fluencies, lm_correctness = [], [], [], []
        for r, (turn_start, turn_stop) in enumerate(zip(turn_starts, turn_stops)):
            ## 3.1. there are only two losses: the stream loss on frame tokens and the lm loss. prepare the corresponding mask for each
            turn_label = label[turn_start:turn_stop]
            turn_learn_mask = turn_label != ignore_token_id
            if not turn_learn_mask.any():
                continue
            turn_logit = logit[turn_start:turn_stop]
            turn_input_id = input_id[turn_start:turn_stop]
            turn_v_mask = turn_input_id == v_placeholder_id
            turn_num_frames = turn_v_mask.sum() // frame_num_tokens
            turn_stream_mask = turn_v_mask & turn_learn_mask
            turn_lm_mask = turn_learn_mask & ~turn_stream_mask

            ## 3.2 ppl, offline metric
            if turn_lm_mask.any():
                turn_lm_masked_logit, turn_lm_masked_label = turn_logit[turn_lm_mask], turn_label[turn_lm_mask]
                lm_ppl = torch.nn.functional.cross_entropy(turn_lm_masked_logit, turn_lm_masked_label).exp()
                lm_ppls.append(lm_ppl)
                turn_lm_masked_wrong_mask = turn_lm_masked_logit.argmax(dim=-1) != turn_lm_masked_label
                if turn_lm_masked_wrong_mask.any():
                    num_lm_correct_tokens = turn_lm_masked_wrong_mask.nonzero()[0,0]
                else:
                    num_lm_correct_tokens = (~turn_lm_masked_wrong_mask).sum()
                lm_correctness.append(num_lm_correct_tokens / turn_lm_masked_label.numel())

            ## 3.3. frame_diff (will be cast to time_diff in compute_metrics)
            if turn_stream_mask.any():
                ## 3.3.1: reply before (at) turn_num_frames
                turn_score = turn_logit.softmax(dim=-1)
                turn_stream_masked_score = turn_score[turn_stream_mask]
                if frame_token_interval_threshold > 0:
                    lower_threshold_mask = turn_stream_masked_score[:, frame_token_interval_id] < frame_token_interval_threshold
                    turn_stream_masked_score[lower_threshold_mask] = 0
                turn_stream_masked_pred_mask = turn_stream_masked_score.argmax(dim=-1) != frame_token_interval_id
                if turn_stream_masked_pred_mask.any():
                    frame_diff = turn_stream_mask.sum() - turn_stream_masked_pred_mask.nonzero()[0,0] - 1
                else:
                    ## 3.3.2: the most complex part, reply after turn_num_frames. we assume the 'Assistant: ...' text does not exist
                    turn_last_stream_idx = turn_stream_mask.nonzero()[-1,0]
                    past_key_values_before_assistant = self.trim_past_key_values(past_key_values, 0, turn_start + turn_last_stream_idx + 1)
                    if r == num_turns - 1: # no future frame. we assume the model receives a signal when streaming ends (e.g. a close button)
                        frame_diff = zero
                    else:
                        next_turn_num_frames = (input_id[turn_starts[r+1]:turn_stops[r+1]] == v_placeholder_id).sum() // frame_num_tokens
                        to_append_num_frames = min(next_turn_num_frames, turn_num_frames - 1) # avoid bias: keep the current turn as the center, with equally many frames on each side
                        if to_append_num_frames == 0:
                            frame_diff = zero
                        else:
                            to_append_frames = frames[past_num_frames+turn_num_frames:past_num_frames+turn_num_frames+to_append_num_frames]
                            frame_placeholder = [v_placeholder_id] * frame_num_tokens
                            if use_interval:
                                frame_placeholder = [frame_token_interval_id] + frame_placeholder
                            to_append_input_id = torch.tensor(frame_placeholder * to_append_num_frames, dtype=torch.long, device=device)
                            to_append_logit = self.forward(
                                input_ids=to_append_input_id[None],
                                past_key_values=past_key_values_before_assistant,
                                frames=to_append_frames,
                                return_dict=True, use_cache=True
                            ).logits[0]
                            # we only use the last idx of each frame
                            idxs = torch.arange(len(frame_placeholder)-1, len(to_append_input_id), len(frame_placeholder), device=device)
                            to_append_score = to_append_logit[idxs].softmax(dim=-1)
                            if frame_token_interval_threshold > 0:
                                lower_threshold_mask = to_append_score[:, frame_token_interval_id] < frame_token_interval_threshold
                                to_append_score[lower_threshold_mask] = 0
                            to_append_score_pred_mask = to_append_score.argmax(dim=-1) != frame_token_interval_id
                            if to_append_score_pred_mask.any():
                                frame_diff = -(to_append_score_pred_mask.nonzero()[0,0] + 1)
                            else:
                                frame_diff = -to_append_num_frames
                frame_diffs.append(frame_diff.abs())

            ## 3.4 fluency
            if turn_lm_mask.any() and turn_stream_mask.any():
                num_learn_v_tokens = turn_stream_mask.sum()
                num_learn_valid_tokens = turn_lm_masked_label.numel() + num_learn_v_tokens
                if frame_diff == 0:
                    fluency = (num_learn_v_tokens + num_lm_correct_tokens) / num_learn_valid_tokens
                elif frame_diff > 0:
                    fluency = (num_learn_v_tokens - frame_diff) / num_learn_valid_tokens
                else:
                    fluency = (num_learn_v_tokens - 1) / num_learn_valid_tokens
                fluencies.append(fluency)
            ## 3.5 next turn
            past_num_frames += turn_num_frames
        lm_ppl = torch.stack(lm_ppls).mean() if lm_ppls else one
        frame_diff = torch.stack(frame_diffs).float().mean() if frame_diffs else zero
        fluency = torch.stack(fluencies).float().mean() if fluencies else one
        lm_correctness = torch.stack(lm_correctness).float().mean() if lm_correctness else one
        return torch.stack([lm_ppl, frame_diff, fluency, lm_correctness])

    def trim_past_key_values(self, past_key_values, start, stop):
        return [[past_keys[:,:,start:stop], past_values[:,:,start:stop]] for past_keys, past_values in past_key_values]

def fast_greedy_generate(*, model: LiveMixin, inputs_embeds: torch.Tensor, past_key_values: Cache, eos_token_id: int, inplace_output_ids: torch.Tensor):
    for i in range(inplace_output_ids.size(1)):
        outputs = model(inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        new_token_id = outputs.logits[:, -1:].argmax(dim=-1)
        inplace_output_ids[:, i] = new_token_id
        if new_token_id == eos_token_id:
            break
        inputs_embeds = model.get_input_embeddings()(new_token_id)
    return inplace_output_ids[:, :i+1], past_key_values

def build_live(
    *,
    is_training: bool,
    config_class: type,
    model_class: type,
    llm_pretrained: str = None,
    finetune_modules: list[str] = None,
    lora_modules: str = None,
    lora_r: int = None,
    lora_alpha: int = None,
    set_vision_inside: bool = False,
    resume_from_checkpoint: str = '',
    attn_implementation: str = 'flash_attention_2',
    torch_dtype: str | torch.dtype = 'auto',
    **kwargs
):
    model = model_class.from_pretrained(llm_pretrained, config=config_class.from_pretrained(llm_pretrained, **kwargs), torch_dtype=torch_dtype, attn_implementation=attn_implementation)
    tokenizer = build_live_tokenizer_and_update_config(llm_pretrained, model.config)
    if is_training:
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_modules,
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
            modules_to_save=finetune_modules,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
            model = PeftModel.from_pretrained(model, resume_from_checkpoint, is_trainable=False)
        else:
            logger.warning(f'!!! Failed to load checkpoint: {resume_from_checkpoint}. Returning a newly initialized model.')
        if set_vision_inside:
            model.set_vision_inside()
        model.requires_grad_(False)
    return model, tokenizer
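For context, a rough decoding sketch around fast_greedy_generate (hypothetical, not part of the uploaded files; it assumes `model`, `tokenizer`, `input_ids`, and pre-extracted `frames` already exist, e.g. from build_live and the tokenizer above):

import torch

inplace_output_ids = torch.full((1, 256), tokenizer.eos_token_id, dtype=torch.long, device=model.device)
inputs_embeds = model.joint_embed(input_ids, frames)  # prompt token ids interleaved with frame features
output_ids, past_key_values = fast_greedy_generate(
    model=model, inputs_embeds=inputs_embeds, past_key_values=None,  # no cached tokens yet
    eos_token_id=tokenizer.eos_token_id, inplace_output_ids=inplace_output_ids,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))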
models/tokenization_live.py
ADDED
@@ -0,0 +1,153 @@
import torch
from transformers import AutoTokenizer
from functools import partial

from .configuration_live import LiveConfigMixin

def get_stream_placeholder_len(num_frames: int, model_config: LiveConfigMixin) -> int:
    return num_frames * model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval) * (num_frames - 1)

def get_stream_placeholder_jinja2(model_config: LiveConfigMixin) -> str:
    return f"'{model_config.frame_token_interval}'.join([{model_config.frame_num_tokens} * '{model_config.v_placeholder}'] * message['num_frames'])"

def get_stream_learn_ranges(num_frames: int, model_config: LiveConfigMixin) -> torch.Tensor:
    len_frame_placeholder_with_interval = model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval)
    intermediate_interval_idxs = torch.arange(
        len_frame_placeholder_with_interval,
        len_frame_placeholder_with_interval * num_frames + 1,
        len_frame_placeholder_with_interval
    ) - len(model_config.frame_token_interval)
    len_learn = len(model_config.frame_token_interval) if model_config.frame_token_interval else len(model_config.v_placeholder)
    learn_ranges = torch.stack([
        intermediate_interval_idxs,
        intermediate_interval_idxs + len_learn
    ], dim=1)
    return learn_ranges

def chat_template(self, stream_placeholder_jinja2: str):
    """
    system prompt
    [<v>,<v>,<v>]
    User: ...
    Assistant: ...</s>
    [<v>,<v>]
    Assistant: ...</s>
    User: ...
    Assistant: ...</s>
    """
    template = (
        "{% if messages[0]['role'] == 'system' %}"
        "{{ bos_token + messages[0]['content'] + '\n' }}" # system
        "{% set messages = messages[1:] %}"
        "{% endif %}"
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{% if add_stream_query_prompt %}"
        "{{ ']\nUser: ' + message['content'] }}"
        "{% else %}"
        "{{ '\nUser: ' + message['content'] }}"
        "{% endif %}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '\nAssistant: ' + message['content'] + eos_token }}"
        "{% elif message['role'] == 'stream' and message['num_frames'] > 0 %}"
        "{{ '\n[' + STREAM_PLACEHOLDER + ']' }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '\nAssistant:' }}"
        "{% elif add_stream_prompt %}"
        "{{ '\n[' }}"
        "{% elif add_stream_generation_prompt %}"
        "{{ ']\nAssistant:' }}"
        "{% endif %}"
    )
    template = template.replace('STREAM_PLACEHOLDER', stream_placeholder_jinja2)
    return template

def chat_template_transition(tokenizer):
    return {
        (None, 'system'): tokenizer.bos_token,
        ('system', 'user'): '\n\nUser: ',
        ('system', 'stream'): '\n\n[',
        ('user', 'assistant'): '\nAssistant: ',
        ('user', 'stream'): '\n[',
        ('user', 'user'): '\nUser: ',
        ('assistant', 'user'): f'{tokenizer.eos_token}\nUser: ',
        ('assistant', 'stream'): f'{tokenizer.eos_token}\n[',
        ('stream', 'user'): ']\nUser: ',
        ('stream', 'assistant'): ']\nAssistant: ',
        'assistant': 'Assistant: ',
        'eos_token': tokenizer.eos_token,
    }

def chat_template_offsets(tokenizer):
    return {k:len(v) for k, v in chat_template_transition(tokenizer).items()}

def get_learn_ranges(conversation: list[dict], *, chat_template_offsets: dict[tuple, int], model_config: LiveConfigMixin):
    offset = 0
    learn_ranges = []
    last_role = None
    for message in conversation:
        role = message['role']
        offset += chat_template_offsets[(last_role, role)]
        last_role = role
        if role == 'stream':
            if message.get('learn', False):
                ranges = get_stream_learn_ranges(message['num_frames'], model_config) + offset
                # the last range is followed by ']\n', so also count the '\n'
                ranges[-1, 1] += 1
                if not isinstance(message['learn'], bool):
                    ranges = ranges[:message['learn']]
                learn_ranges.extend([range(r[0], r[1]) for r in ranges])
            offset += get_stream_placeholder_len(message['num_frames'], model_config)
        else:
            if role == 'assistant':
                if message.get('learn', False):
                    learn_ranges.append(range(offset - chat_template_offsets['assistant'], offset + len(message['content']) + chat_template_offsets['eos_token']))
            offset += len(message['content'])
    return learn_ranges

def build_live_tokenizer_and_update_config(llm_pretrained: str, model_config: LiveConfigMixin) -> AutoTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(llm_pretrained, use_fast=True, padding_side='left')
    tokenizer.add_special_tokens({'additional_special_tokens': [model_config.v_placeholder]})
    v_placeholder_id = len(tokenizer) - 1
    if model_config.frame_token_interval:
        frame_token_interval_id = tokenizer.convert_tokens_to_ids(model_config.frame_token_interval)
    else:
        frame_token_interval_id = None
    tokenizer.pad_token = tokenizer.eos_token
    model_config.update(dict(v_placeholder_id=v_placeholder_id, frame_token_interval_id=frame_token_interval_id, eos_token_id=tokenizer.eos_token_id))
    tokenizer.chat_template = chat_template(tokenizer, get_stream_placeholder_jinja2(model_config))
    tokenizer.get_learn_ranges = partial(get_learn_ranges, chat_template_offsets=chat_template_offsets(tokenizer), model_config=model_config)
    return tokenizer

if __name__ == '__main__':
    config = LiveConfigMixin(frame_token_interval=',', frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10)
    tokenizer = build_live_tokenizer_and_update_config('meta-llama/Meta-Llama-3-8B-Instruct', config)
    chat = [
        {'role': 'system', 'content': 'cool.'},
        {'role': 'stream', 'num_frames': 2, 'learn': 1},
        {'role': 'user', 'content': 'cool?'},
        {'role': 'assistant', 'content': 'cool.', 'learn': True},
        {'role': 'stream', 'num_frames': 3, 'learn': 3},
        {'role': 'assistant', 'content': 'so cool.', 'learn': True},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    learn_ranges = tokenizer.get_learn_ranges(chat)
    batch = tokenizer([prompt], return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True)
    batch_labels = torch.full_like(batch.input_ids, -100, dtype=torch.long)
    for text, labels, input_ids, offset_mapping, learn_range in zip(
        [prompt], batch_labels, batch.input_ids, batch.offset_mapping, [learn_ranges]
    ):
        for learn_r in learn_range:
            start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item()
            if offset_mapping[:,0][-1] >= learn_r.stop:
                stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item()
            else: # the last eos token
                stop = len(input_ids)
            labels[start-1:stop-1] = input_ids[start:stop]
        # NOTE: some label positions hold the v_placeholder_id (the token appended at the end of the tokenizer),
        # which lies outside the original vocabulary; replace those targets with the eos token.
        labels[labels >= len(tokenizer) - 1] = tokenizer.eos_token_id
    print(batch.input_ids)
    print(batch_labels)
models/vision_live.py
ADDED
@@ -0,0 +1,61 @@
import math, torch
from functools import partial
from torch import nn, Tensor
from torchvision.transforms.functional import normalize
from transformers import AutoModel
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

from .configuration_live import LiveConfigMixin

def _siglip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
    mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], rescale_factor=0.00392156862745098, **kwargs):
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
        last_hidden_state = vision_outputs.last_hidden_state
    if frame_token_pooled:
        s = int(math.sqrt(last_hidden_state.shape[1]))
        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
            last_hidden_state.reshape(
                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
            ).permute(0, 3, 1, 2),
            frame_token_pooled
        ).flatten(2, 3).permute(0, 2, 1)
        if not frame_token_cls:
            return spatial_tokens
    if frame_token_cls:
        cls_token = vision_outputs.pooler_output[:, None]
        if not frame_token_pooled:
            return cls_token
    return torch.cat([cls_token, spatial_tokens], dim=1)

def _clip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
    mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, rescale_factor=0.00392156862745098, **kwargs):
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
        last_hidden_state = vision_outputs.last_hidden_state
    if frame_token_pooled:
        s = int(math.sqrt(last_hidden_state.shape[1]))
        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
            last_hidden_state[:,1:].reshape(
                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
            ).permute(0, 3, 1, 2),
            frame_token_pooled
        ).flatten(2, 3).permute(0, 2, 1)
        if not frame_token_cls:
            return spatial_tokens
    if frame_token_cls:
        cls_token = last_hidden_state[:, :1] # keep shape (batch, 1, hidden) so it can be concatenated with the pooled tokens
        if not frame_token_pooled:
            return cls_token
    return torch.cat([cls_token, spatial_tokens], dim=1)

def build_live_vision(config: LiveConfigMixin):
    model = AutoModel.from_pretrained(config.vision_pretrained).vision_model
    if 'google/siglip-large-patch16-384' == config.vision_pretrained:
        return model, partial(_siglip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
    elif 'laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k' == config.vision_pretrained or 'openai/clip-vit-large-patch14-336' == config.vision_pretrained:
        # mirror the siglip branch: bind the pooling options, leaving (vision_model, frames) to the caller
        return model, partial(_clip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
    else:
        raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}')
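A hypothetical feature pre-extraction sketch using build_live_vision (not part of the uploaded files; the frame tensor layout and dtype handling are assumptions):

import torch

config = LiveConfigMixin(vision_pretrained='google/siglip-large-patch16-384',
    frame_token_cls=True, frame_token_pooled=[3, 3], frame_num_tokens=10)
vision_model, vision_encode = build_live_vision(config)
frames = torch.randint(0, 256, (4, 3, 384, 384)).float()  # 4 RGB frames in 0..255
with torch.no_grad():
    features = vision_encode(vision_model, frames)  # expected (4, 10, 1024): 1 CLS token + 3x3 pooled tokens
print(features.shape)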