Spaces:
Sleeping
Sleeping
from typing import Optional | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
import numpy as np | |
from dataclasses import dataclass | |
from .transformer import ( | |
LayerNormFp32, | |
LayerNorm, | |
QuickGELU, | |
MultimodalTransformer, | |
) | |
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower | |
# `transformers` is an optional dependency: it is only needed by `CoCa.generate`.
# `_has_transformers` records whether the import succeeded so that callers can
# fail with a clear message instead of a NameError.
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StoppingCriteriaList
    )
except ImportError:
    _has_transformers = False
else:
    _has_transformers = True

# Maps a `generation_type` name to the logits-warper class used for sampling.
# "beam_search" has no warper and simply maps to its own name; the warper
# entries are None when `transformers` is not installed.
GENERATION_TYPES = {
    "top_k": TopKLogitsWarper if _has_transformers else None,
    "top_p": TopPLogitsWarper if _has_transformers else None,
    "beam_search": "beam_search",
}
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration of the multimodal text decoder, extending `CLIPTextCfg`.

    The `@dataclass` decorator is required so that the fields declared here are
    accepted as keyword arguments by `MultimodalCfg(**cfg_dict)`; without it only
    the constructor inherited from `CLIPTextCfg` would be available.
    NOTE(review): this assumes `CLIPTextCfg` is itself a dataclass - confirm in
    `.model`.
    """

    # Only `heads` is read in this module (it is forwarded to
    # `MultimodalTransformer`); the other fields are presumably consumed
    # elsewhere - verify against their users before changing them.
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the `MultimodalTransformer` used as the text decoder.

    Args:
        embed_dim: forwarded to the transformer as `output_dim`.
        multimodal_cfg: a `MultimodalCfg`, or a dict of its keyword arguments.
        quick_gelu: use `QuickGELU` instead of `nn.GELU` as activation.
        cast_dtype: when float16 or bfloat16, `LayerNormFp32` is used as the
            normalisation layer instead of `LayerNorm`.

    Returns:
        The constructed `MultimodalTransformer`.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    if quick_gelu:
        activation = QuickGELU
    else:
        activation = nn.GELU

    reduced_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalisation = LayerNormFp32 if reduced_precision else LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalisation,
    )
class CoCa(nn.Module):
    """Model combining a text tower, a vision tower and a multimodal text decoder.

    `forward` returns both the (normalised) image/text latents and the decoder
    logits over the vocabulary; `generate` produces token sequences for images
    by beam search or top-k / top-p sampling.
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        """Build the three towers.

        Args:
            embed_dim: output dimension of the text and vision towers.
            multimodal_cfg: decoder config (`MultimodalCfg` or dict of its kwargs).
            text_cfg: text tower config (`CLIPTextCfg` or dict of its kwargs).
            vision_cfg: vision tower config (`CLIPVisionCfg` or dict of its kwargs).
            quick_gelu: forwarded to all tower builders.
            cast_dtype: forwarded to all tower builders.
            pad_id: default padding token id used by `generate`.
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The decoder projects to the vocabulary, so its output dimension is the
        # text vocabulary size rather than `embed_dim`.
        self.text_decoder = _build_text_decoder_tower(
            text_cfg.vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    def set_grad_checkpointing(self, enable=True):
        """Forward the gradient-checkpointing switch to all three towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        """Run the vision tower; return `(image_latent, token_embeddings)`.

        The latent is L2-normalised along the last dimension when `normalize`.
        """
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True, embed_cls=True):
        """Run the text tower; return `(text_latent, token_embeddings)`.

        With `embed_cls` the last token of `text` is dropped before encoding to
        make space for the CLS token. The latent is L2-normalised along the last
        dimension when `normalize`.
        """
        text = text[:, :-1] if embed_cls else text  # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        """Return only the image latent of `_encode_image`."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        """Return only the text latent of `_encode_text`."""
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        """Encode text (and image, unless both image outputs are supplied) and decode.

        Args:
            image: input images; only encoded when `image_latent` or
                `image_embs` is None.
            text: token ids, indexed as `(batch, sequence)`.
            embed_cls: forwarded to `_encode_text`.
            image_latent: optional precomputed image latent.
            image_embs: optional precomputed image token embeddings.

        Returns:
            dict with keys "image_features", "text_features", "logits",
            "labels" (the trailing tokens of `text`, as many as there are text
            token embeddings) and "logit_scale" (exponentiated).
        """
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    @staticmethod
    def _should_stop(stopping_criteria, input_ids):
        """Evaluate the stopping criteria and reduce the result to one bool.

        The criteria may return a plain bool or a per-sequence boolean tensor
        (depending on the installed `transformers` version); generation stops
        only when every sequence is done.
        """
        return bool(torch.as_tensor(stopping_criteria(input_ids, None)).all())

    def generate(
            self,
            image,
            text=None,
            seq_len=30,
            max_seq_len=77,
            temperature=1.,
            generation_type="beam_search",
            top_p=0.1,  # threshold passed to TopPLogitsWarper
            top_k=1,  # keeps the top_k most probable tokens
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            repetition_penalty=1.0,
            fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Generate token sequences for `image`.

        Args:
            image: batch of input images.
            text: optional prompt token ids for the sampling modes (1-D or
                2-D); defaults to a single start-of-text token per image.
                Ignored by beam search, which always starts from `sot_token_id`.
            seq_len: maximum generated length (used for the default
                `MaxLengthCriteria` and for forcing EOS when sampling).
            max_seq_len: number of trailing tokens fed to the model per step
                when sampling.
            temperature: divides the filtered logits before the softmax.
            generation_type: "beam_search", "top_p" or "top_k".
            top_p: argument of `TopPLogitsWarper` ("top_p" mode).
            top_k: argument of `TopKLogitsWarper` ("top_k" mode).
            pad_token_id: padding id; defaults to `self.pad_id`.
            eos_token_id: end-of-text id; defaults to 49407.
            sot_token_id: start-of-text id; defaults to 49406.
            num_beams, num_beam_groups: beam-search settings.
            min_seq_len: minimum length enforced via `MinLengthLogitsProcessor`.
            stopping_criteria: optional list of stopping criteria; defaults to
                `[MaxLengthCriteria(max_length=seq_len)]`.
            repetition_penalty: argument of `RepetitionPenaltyLogitsProcessor`.
            fixed_output_length: pad the output with `pad_token_id` up to
                `seq_len`.

        Returns:
            Tensor of generated token ids.

        Raises:
            AssertionError: if `transformers` is missing or
                `seq_len <= min_seq_len`.
            ValueError: if `generation_type` is unknown.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        device = image.device

        # Run every generation mode in eval mode and always restore the
        # previous mode, even when generation raises.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                sot_token_id = 49406 if sot_token_id is None else sot_token_id
                eos_token_id = 49407 if eos_token_id is None else eos_token_id
                pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
                logit_processor = LogitsProcessorList(
                    [
                        MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                        RepetitionPenaltyLogitsProcessor(repetition_penalty),
                    ]
                )

                if stopping_criteria is None:
                    stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
                stopping_criteria = StoppingCriteriaList(stopping_criteria)

                if generation_type == "beam_search":
                    output = self._generate_beamsearch(
                        image_inputs=image,
                        pad_token_id=pad_token_id,
                        eos_token_id=eos_token_id,
                        sot_token_id=sot_token_id,
                        num_beams=num_beams,
                        num_beam_groups=num_beam_groups,
                        min_seq_len=min_seq_len,
                        stopping_criteria=stopping_criteria,
                        logit_processor=logit_processor,
                    )
                    if fixed_output_length and output.shape[1] < seq_len:
                        # Pad with the resolved pad id (honours a caller-supplied
                        # `pad_token_id`, falling back to `self.pad_id`).
                        padding = torch.full(
                            (output.shape[0], seq_len - output.shape[1]),
                            pad_token_id,
                            device=device,
                            dtype=output.dtype,
                        )
                        return torch.cat((output, padding), dim=1)
                    return output

                elif generation_type == "top_p":
                    logit_warper = GENERATION_TYPES[generation_type](top_p)
                elif generation_type == "top_k":
                    logit_warper = GENERATION_TYPES[generation_type](top_k)
                else:
                    raise ValueError(
                        f"generation_type has to be one of "
                        f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                    )

                # The image is encoded once and reused for every decoding step.
                image_latent, image_embs = self._encode_image(image)

                if text is None:
                    text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

                num_dims = len(text.shape)
                if num_dims == 1:
                    text = text[None, :]

                out = text
                while True:
                    x = out[:, -max_seq_len:]
                    cur_len = x.shape[1]
                    logits = self(
                        image,
                        x,
                        image_latent=image_latent,
                        image_embs=image_embs,
                        embed_cls=False,
                    )["logits"][:, -1]

                    # Sequences that already ended (EOS or padding) only receive padding.
                    mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                    sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                    if mask.all():
                        if not fixed_output_length:
                            break
                    else:
                        active = ~mask
                        logits = logits[active, :]
                        filtered_logits = logit_processor(x[active, :], logits)
                        filtered_logits = logit_warper(x[active, :], filtered_logits)
                        probs = F.softmax(filtered_logits / temperature, dim=-1)

                        if cur_len + 1 == seq_len:
                            # Last allowed position: force end-of-text.
                            sample[active, :] = eos_token_id
                        else:
                            sample[active, :] = torch.multinomial(probs, 1)

                    out = torch.cat((out, sample), dim=-1)

                    if self._should_stop(stopping_criteria, out):
                        break

                if num_dims == 1:
                    out = out.squeeze(0)

                return out
        finally:
            self.train(was_training)

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Group beam search over the decoder, scored with `BeamSearchScorer`.

        Args:
            image_inputs: batch of images; each is repeated `num_beams` times.
            pad_token_id, eos_token_id: forwarded to the beam scorer.
            sot_token_id: start token of every beam; must not be None because
                the initial ids are built by multiplying with it.
            num_beams, num_beam_groups: beam-search settings.
            min_seq_len: used for the default `MinLengthLogitsProcessor` when
                `logit_processor` is None.
            stopping_criteria: must be a callable criteria list exposing
                `max_length` (it is called every step); None is not supported.
            logit_processor: optional `LogitsProcessorList`.
            logit_warper: accepted for signature compatibility; unused.

        Returns:
            The "sequences" entry of `BeamSearchScorer.finalize`.

        Raises:
            ValueError: if the batch dimension of the initial ids does not
                equal `num_beams * batch_size`.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size = input_ids.shape[0]

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:
            # predicted tokens in the current step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = [
                    batch_idx * num_beams + idx
                    for batch_idx in range(batch_size)
                    for idx in range(group_start_idx, group_end_idx)
                ]
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # A flat index encodes (beam within group, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless: beam indices are not tracked here, so None is passed
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=None,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            if beam_scorer.is_done or self._should_stop(stopping_criteria, input_ids):
                break

        # `next_tokens` / `next_indices` are those of the last processed group.
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble the model-input dict for one generation step.

    Args:
        input_ids: token ids indexed as `(batch, sequence)`; when `past` is
            truthy only the last token of each row is kept.
        image_inputs: returned unchanged under the "images" key.
        past: optional cached state, returned under "past_key_values".
        **kwargs: may contain "attention_mask" and/or "position_ids".

    Returns:
        dict with keys "text", "images", "past_key_values", "position_ids" and
        "attention_mask". Explicitly supplied `position_ids` are passed
        through; when only an attention mask is given, position ids are derived
        from it; otherwise "position_ids" is None.
    """
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask")
    position_ids = kwargs.get("position_ids")

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation: running count of
        # attended tokens, with masked positions set to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }