import torch
import yaml
from transformers.utils.hub import get_file_from_repo

from .clip.clip_encoder import CLIPVisionTower
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
from .internvit.internvit_encoder import InternViTVisionTower
from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
from .whale.init_model import init_model

def build_vision_tower(vision_tower_cfg, **kwargs):
    # Resolve the tower name/path, preferring the multimodal-specific field.
    vision_tower = getattr(
        vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
    )
    use_s2 = getattr(vision_tower_cfg, "use_s2", False)

    # Dispatch on the checkpoint name; S2 (multi-scale) is only wired up for SigLIP.
    if "sig" in vision_tower.lower():
        if use_s2:
            return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        else:
            return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "eva" in vision_tower.lower():
        if use_s2:
            raise ValueError("Currently not supporting S2 for EVA-CLIP")
        else:
            return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "clip" in vision_tower.lower():
        if use_s2:
            raise ValueError("Currently not supporting S2 for CLIP")
        else:
            return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif "internvit" in vision_tower.lower():
        if use_s2:
            raise ValueError("Currently not supporting S2 for InternViT")
        else:
            return InternViTVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f"Unknown vision tower: {vision_tower}")
def build_audio_encoder(audio_encoder_config, **kwargs):
    # Fetch the encoder's training config from the Hub repo and patch in runtime options.
    with open(get_file_from_repo(audio_encoder_config.mm_audio_encoder, "train.yaml"), "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    configs["cmvn_file"] = get_file_from_repo(audio_encoder_config.mm_audio_encoder, "global_cmvn")
    configs["model_conf"]["freeze_encoder"] = getattr(
        audio_encoder_config, "freeze_audio_encoder", True
    )
    # "freeze_adpter" (sic) is the key name the whale model config expects.
    configs["model_conf"]["freeze_adpter"] = getattr(
        audio_encoder_config, "freeze_audio_encoder_adapter", True
    )
    configs["model_conf"]["audio_prompt_finetune"] = getattr(
        audio_encoder_config, "audio_prompt_finetune", False
    )
    configs["model_conf"]["audio_prompt_num"] = getattr(
        audio_encoder_config, "audio_prompt_num", 0
    )
    audio_encoder = init_model(configs)

    # Load pretrained weights, but only for tensors whose shapes still match;
    # mismatched or missing keys keep their freshly initialized values.
    checkpoint = torch.load(
        get_file_from_repo(audio_encoder_config.mm_audio_encoder, "final.pt"), map_location="cpu"
    )
    model_dict = audio_encoder.state_dict()
    for key in model_dict.keys():
        if key in checkpoint.keys():
            if model_dict[key].shape == checkpoint[key].shape:
                model_dict[key] = checkpoint[key]
            else:
                print(
                    "Key {} has a different shape: {} vs {}".format(
                        key, model_dict[key].shape, checkpoint[key].shape
                    )
                )
        else:
            print("Key {} is not in the resume checkpoint".format(key))
    audio_encoder.load_state_dict(model_dict)
    return audio_encoder
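
# Usage sketch (the repo id below is an assumption, not confirmed by this module):
# the config object needs at least mm_audio_encoder pointing at a Hub repo that
# contains train.yaml, global_cmvn, and final.pt; the freeze/prompt fields fall
# back to the defaults above.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(mm_audio_encoder="VITA-MLLM/VITA-1.5")  # assumed repo id
#     encoder = build_audio_encoder(cfg)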