import importlib
import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from transformers import GenerationConfig, PreTrainedTokenizer, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

try:
    from einops import rearrange
except ImportError:
    rearrange = None
from torch import nn
from .configuration_infimm_hd import InfiMMHDConfig
from .eva_vit_model import CLIPVisionCfg, EVAVisionTransformer
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import _infer_decoder_layers_attr_name, extend_instance
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7

class InfiMMPreTrainedModel(PreTrainedModel):
    config_class = InfiMMHDConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

class InfiMMHDModel(InfiMMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.vision_config = config.visual
        vision_encoder = self.build_vision_encoder()
        self.language_config = config.language
        language_encoder = self.build_language_encoder()
        self.model = self.build_flamingo(vision_encoder, language_encoder)

    def build_vision_encoder(self, image_size=448):
        # Build the EVA-ViT vision encoder from the vision config,
        # optionally overriding the input resolution.
        vision_cfg = CLIPVisionCfg(**self.vision_config)
        if image_size:
            vision_cfg.image_size = image_size
        vision_encoder = EVAVisionTransformer(
            img_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            num_classes=vision_cfg.embed_dim,
            use_mean_pooling=vision_cfg.global_average_pool,  # False
            init_values=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            embed_dim=vision_cfg.width,
            depth=vision_cfg.layers,
            num_heads=vision_cfg.width // vision_cfg.head_width,
            mlp_ratio=vision_cfg.mlp_ratio,
            qkv_bias=vision_cfg.qkv_bias,
            drop_path_rate=vision_cfg.drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            xattn=vision_cfg.xattn,
            rope=vision_cfg.rope,
            postnorm=vision_cfg.postnorm,
            pt_hw_seq_len=vision_cfg.pt_hw_seq_len,  # 224/14
            intp_freq=vision_cfg.intp_freq,
            naiveswiglu=vision_cfg.naiveswiglu,
            subln=vision_cfg.subln,
        )
        return vision_encoder

    def build_language_encoder(self):
        # Load the causal LM backbone and resize its token embeddings
        # to the configured vocabulary size.
        lang_encoder = AutoModelForCausalLM.from_pretrained(
            self.language_config["_name_or_path"]
        )
        lang_encoder.resize_token_embeddings(self.language_config["vocab_size"])
        return lang_encoder

    def build_flamingo(self, vision_encoder, lang_encoder):
        # Mix the Flamingo LM mixin into the language model, then wrap both
        # encoders in a Flamingo model with cross-attention every n layers.
        extend_instance(lang_encoder, FlamingoLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        model = Flamingo(
            vision_encoder,
            lang_encoder,
            self.config.eoc_token_id,
            self.config.image_token_id,
            vis_dim=self.vision_config["width"],
            cross_attn_every_n_layers=self.config.cross_attn_every_n_layers,
            gradient_checkpointing=self.config.use_grad_checkpoint,
        )
        return model

    def generate(
        self,
        batch_images,
        input_ids,
        attention_mask,
        **kwargs,
    ):
        with torch.inference_mode():
            outputs = self.model.generate(
                batch_images,
                input_ids,
                attention_mask,
                **kwargs,
            )

        # Extract only the newly generated tokens, dropping the prompt prefix.
        outputs = outputs[:, len(input_ids[0]):]
        return outputs
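

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not part of the original module): the
# checkpoint path, prompt format, and image-tensor layout below are
# illustrative placeholders; consult the model card and preprocessing code
# for the real values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical checkpoint path; InfiMMHDModel is a PreTrainedModel, so
    # from_pretrained() loads config and weights in the usual way.
    checkpoint = "path/to/infimm-hd"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = InfiMMHDModel.from_pretrained(checkpoint).eval()

    # Dummy image batch in the (batch, num_media, num_frames, channels, H, W)
    # layout that Flamingo-style models commonly expect; 448 matches the
    # default image_size in build_vision_encoder. Real inputs would come from
    # the model's image preprocessing pipeline.
    batch_images = torch.zeros(1, 1, 1, 3, 448, 448)

    # Assumed prompt format with an in-text image placeholder token.
    inputs = tokenizer("<image>Describe this image.", return_tensors="pt")

    generated = model.generate(
        batch_images,
        inputs["input_ids"],
        inputs["attention_mask"],
        max_new_tokens=32,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))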