|
from typing import Dict, List, Optional, Tuple, Union, Literal |
|
from dataclasses import dataclass |
|
|
|
import json |
|
import math |
|
import logging |
|
|
|
import numpy as np |
|
from tqdm import tqdm |
|
from threading import Thread |
|
from PIL import Image |
|
import soundfile as sf |
|
from copy import deepcopy |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.nn.utils.parametrize as P |
|
from torch.nn.utils.parametrizations import weight_norm |
|
from torch.nn.utils.rnn import pad_sequence |
|
from vector_quantize_pytorch import GroupedResidualFSQ |
|
from vocos import Vocos |
|
from vocos.pretrained import instantiate_class |
|
|
|
from transformers import AutoProcessor, TextIteratorStreamer, PreTrainedModel, LogitsWarper, BertTokenizerFast, \ |
|
TopPLogitsWarper, TopKLogitsWarper, Qwen2PreTrainedModel, Qwen2ForCausalLM |
|
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPast |
|
from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperConfig, WHISPER_ATTENTION_CLASSES, ACT2FN |
|
from transformers.cache_utils import EncoderDecoderCache, DynamicCache |
|
from transformers import LlamaConfig, LlamaModel |
|
|
|
from .configuration_minicpm import MiniCPMOConfig, ConditionalChatTTSConfig |
|
from .modeling_navit_siglip import SiglipVisionTransformer |
|
from .resampler import Resampler |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
padding_logged = False |
|
|
|
|
|
class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel): |
|
config_class = MiniCPMOConfig |
|
|
|
|
|
class MiniCPMO(MiniCPMOPreTrainedModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.llm = Qwen2ForCausalLM(config) |
|
self.vpm = self.init_vision_module() |
|
self.apm = self.init_audio_module() |
|
self.tts = self.init_tts_module() |
|
self.vision_dim = self.vpm.embed_dim |
|
self.embed_dim = self.llm.config.hidden_size |
|
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) |
|
|
|
audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) |
|
embed_dim = self.llm.config.hidden_size |
|
|
|
self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step) |
|
self.audio_projection_layer = MultiModalProjector( |
|
in_dim=audio_output_dim, |
|
out_dim=embed_dim |
|
) |
|
|
|
self.audio_encoder_layer = -1 |
|
self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) |
|
|
|
self.terminators = ['<|im_end|>', '</s>'] |
|
|
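        # Chat template used when use_tts=True: it appends the speaker-embedding and TTS begin
        # tokens (<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>) to the assistant generation prompt.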
|
self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" |
|
|
|
|
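        # NOTE: hardcoded developer-local path; this must point to the ChatTTS text tokenizer
        # assets available on the running machine.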
|
tts_text_tokenizer = BertTokenizerFast.from_pretrained("/mnt/data/user/tc_agi/xubokai/ChatTTS/asset/tokenizer") |
|
from .processing_minicpmo import ChatTTSProcessor |
|
self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer) |
|
|
|
|
|
self.vocos = None |
|
|
|
self.streaming_text_chunk_size = 11 |
|
        self.force_no_stop = False
|
self._generate = self.generate |
|
|
|
def initialize_vocos(self): |
|
feature_extractor = instantiate_class( |
|
args=(), init={'class_path': 'vocos.feature_extractors.MelSpectrogramFeatures', |
|
'init_args': {'sample_rate': 24000, 'n_fft': 1024, 'hop_length': 256, 'n_mels': 100}} |
|
) |
|
backbone = instantiate_class( |
|
args=(), init={'class_path': 'vocos.models.VocosBackbone', |
|
'init_args': {'input_channels': 100, 'dim': 512, 'intermediate_dim': 1536, |
|
'num_layers': 8}} |
|
) |
|
head = instantiate_class( |
|
args=(), init={'class_path': 'vocos.heads.ISTFTHead', |
|
'init_args': {'dim': 512, 'n_fft': 1024, 'hop_length': 256}} |
|
) |
|
vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32) |
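        # NOTE: hardcoded developer-local path; this must point to the Vocos vocoder checkpoint
        # available on the running machine.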
|
vocos.load_state_dict( |
|
torch.load('/mnt/data/user/tc_agi/xubokai/ChatTTS/asset/Vocos.pt', weights_only=True, mmap=True)) |
|
return vocos |
|
|
|
def init_vision_module(self): |
|
|
|
if self.config._attn_implementation == 'flash_attention_2': |
|
self.config.vision_config._attn_implementation = 'flash_attention_2' |
|
else: |
|
|
|
self.config.vision_config._attn_implementation = 'eager' |
|
model = SiglipVisionTransformer(self.config.vision_config) |
|
if self.config.drop_vision_last_layer: |
|
model.encoder.layers = model.encoder.layers[:-1] |
|
|
|
setattr(model, 'embed_dim', model.embeddings.embed_dim) |
|
setattr(model, 'patch_size', model.embeddings.patch_size) |
|
|
|
return model |
|
|
|
def init_resampler(self, embed_dim, vision_dim): |
|
return Resampler( |
|
num_queries=self.config.query_num, |
|
embed_dim=embed_dim, |
|
num_heads=embed_dim // 128, |
|
kv_dim=vision_dim, |
|
adaptive=True |
|
) |
|
|
|
|
|
def init_audio_module(self): |
|
model = MiniCPMWhisperEncoder(self.config.audio_config) |
|
return model |
|
|
|
|
|
def init_tts_module(self): |
|
model = ConditionalChatTTS(self.config.tts_config) |
|
return model |
|
|
|
def get_input_embeddings(self): |
|
return self.llm.get_input_embeddings() |
|
|
|
def set_input_embeddings(self, value): |
|
self.llm.embed_tokens = value |
|
|
|
def get_output_embeddings(self): |
|
return self.llm.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.llm.lm_head = new_embeddings |
|
|
|
def set_decoder(self, decoder): |
|
self.llm = decoder |
|
|
|
def get_decoder(self): |
|
return self.llm |
|
|
|
def subsequent_chunk_mask( |
|
self, |
|
size: int, |
|
chunk_size: int, |
|
num_left_chunks: int = -1, |
|
device: torch.device = torch.device("cpu"), |
|
num_lookhead: int = 0 |
|
) -> torch.Tensor: |
|
"""Create mask for subsequent steps (size, size) with chunk size, |
|
this is for streaming encoder |
|
|
|
Args: |
|
size (int): size of mask |
|
chunk_size (int): size of chunk |
|
num_left_chunks (int): number of left chunks |
|
<0: use full chunk |
|
>=0: use num_left_chunks |
|
            device (torch.device): "cpu" or "cuda" or torch.Tensor.device
            num_lookhead (int): number of additional future frames each position may attend to
|
|
|
Returns: |
|
torch.Tensor: mask |
|
|
|
Examples: |
|
>>> subsequent_chunk_mask(4, 2) |
|
[[1, 1, 0, 0], |
|
[1, 1, 0, 0], |
|
[1, 1, 1, 1], |
|
[1, 1, 1, 1]] |
|
""" |
|
ret = torch.zeros(size, size, device=device, dtype=torch.bool) |
|
for i in range(size): |
|
if num_left_chunks < 0: |
|
start = 0 |
|
else: |
|
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) |
|
ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) |
|
ret[i, start:ending] = True |
|
return ret |
|
|
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): |
|
""" |
|
Computes the output length of the convolutional layers and the output length of the audio encoder |
|
""" |
|
input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 |
|
input_lengths_after_pooling = (input_lengths_after_cnn - self.config.audio_pool_step) // self.config.audio_pool_step + 1 |
|
input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) |
|
|
|
return input_lengths_after_cnn, input_lengths_after_pooling |
|
|
|
def get_vllm_embedding(self, data): |
|
if 'vision_hidden_states' not in data: |
|
dtype = self.llm.model.embed_tokens.weight.dtype |
|
device = self.llm.model.embed_tokens.weight.device |
|
tgt_sizes = data['tgt_sizes'] |
|
pixel_values_list = data['pixel_values'] |
|
vision_hidden_states = [] |
|
all_pixel_values = [] |
|
img_cnt = [] |
|
for pixel_values in pixel_values_list: |
|
img_cnt.append(len(pixel_values)) |
|
all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) |
|
|
|
|
|
if all_pixel_values: |
|
tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] |
|
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) |
|
|
|
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) |
|
|
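                # Pad the variable-length patch sequences to a common length and build a
                # per-image patch attention mask from tgt_sizes.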
|
all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, |
|
padding_value=0.0) |
|
B, L, _ = all_pixel_values.shape |
|
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) |
|
|
|
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) |
|
for i in range(B): |
|
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True |
|
|
|
vision_batch_size = self.config.vision_batch_size |
|
all_pixel_values = all_pixel_values.type(dtype) |
|
if B > vision_batch_size: |
|
hs = [] |
|
for i in range(0, B, vision_batch_size): |
|
start_idx = i |
|
end_idx = i + vision_batch_size |
|
tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state |
|
hs.append(tmp_hs) |
|
vision_embedding = torch.cat(hs, dim=0) |
|
else: |
|
vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state |
|
vision_embedding = self.resampler(vision_embedding, tgt_sizes) |
|
|
|
start = 0 |
|
for pixel_values in pixel_values_list: |
|
img_cnt = len(pixel_values) |
|
if img_cnt > 0: |
|
vision_hidden_states.append(vision_embedding[start: start + img_cnt]) |
|
start += img_cnt |
|
else: |
|
vision_hidden_states.append([]) |
|
else: |
|
if self.training: |
|
dummy_image = torch.zeros( |
|
(1, 3, 224, 224), |
|
device=device, dtype=dtype |
|
) |
|
tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32) |
|
dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) |
|
else: |
|
dummy_feature = [] |
|
for _ in range(len(pixel_values_list)): |
|
vision_hidden_states.append(dummy_feature) |
|
|
|
else: |
|
vision_hidden_states = data['vision_hidden_states'] |
|
|
|
if hasattr(self.llm.config, 'scale_emb'): |
|
vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb |
|
else: |
|
vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) |
|
|
|
vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance( |
|
i, torch.Tensor) else i for i in vision_hidden_states] |
|
|
|
bs = len(data['input_ids']) |
|
for i in range(bs): |
|
cur_vs_hs = vision_hidden_states[i] |
|
if len(cur_vs_hs) > 0: |
|
cur_vllm_emb = vllm_embedding[i] |
|
cur_image_bound = data['image_bound'][i] |
|
if len(cur_image_bound) > 0: |
|
image_indices = torch.stack( |
|
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] |
|
).to(vllm_embedding.device) |
|
|
|
cur_vllm_emb = cur_vllm_emb.scatter(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), |
|
cur_vs_hs.view(-1, cur_vs_hs.shape[-1])) |
|
|
|
|
|
elif self.training: |
|
cur_vllm_emb += cur_vs_hs[0].mean() * 0 |
|
|
|
return vllm_embedding, vision_hidden_states |
|
|
|
def get_audio_embedding(self, data, chunk_length=-1, dummy=True): |
|
dtype = self.apm.embed_positions.weight.dtype |
|
device = self.apm.embed_positions.weight.device |
|
|
|
wavforms = data.get('audio_features', []) |
|
audio_feature_lens_raw = data.get('audio_feature_lens', []) |
|
|
|
|
|
if len(wavforms) > 0: |
|
audio_feature_lens = torch.hstack(audio_feature_lens_raw) |
|
batch_size, _, max_mel_seq_len = wavforms.shape |
|
max_seq_len = (max_mel_seq_len - 1) // 2 + 1 |
|
|
|
|
|
seq_range = ( |
|
torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device) |
|
.unsqueeze(0) |
|
.expand(batch_size, max_seq_len) |
|
) |
|
lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) |
|
|
|
padding_mask = seq_range >= lengths_expand |
|
|
|
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( |
|
batch_size, 1, max_seq_len, max_seq_len |
|
) |
|
audio_attention_mask = audio_attention_mask_.to( |
|
dtype=self.apm.conv1.weight.dtype, |
|
device=self.apm.conv1.weight.device |
|
) |
|
|
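            # Optionally restrict self-attention to causal chunks of `chunk_length` seconds
            # (the encoder runs at 50 frames per second), emulating streaming audio encoding.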
|
if chunk_length > 0: |
|
chunk_num_frame = int(chunk_length * 50) |
|
chunk_mask = self.subsequent_chunk_mask( |
|
size=max_seq_len, |
|
chunk_size=chunk_num_frame, |
|
num_left_chunks=-1, |
|
device=audio_attention_mask_.device |
|
) |
|
audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) |
|
|
|
audio_attention_mask[audio_attention_mask_] = float("-inf") |
|
audio_states = self.apm( |
|
wavforms, |
|
output_hidden_states=True, |
|
attention_mask=audio_attention_mask).hidden_states[self.audio_encoder_layer] |
|
|
|
audio_embeds = self.audio_projection_layer(audio_states) |
|
|
|
audio_embeds = audio_embeds.transpose(1, 2) |
|
audio_embeds = self.audio_avg_pooler(audio_embeds) |
|
audio_embeds = audio_embeds.transpose(1, 2) |
|
|
|
_, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) |
|
|
|
num_audio_tokens = feature_lens_after_pooling |
|
|
|
final_audio_embeds = [] |
|
idx = 0 |
|
for i in range(len(audio_feature_lens_raw)): |
|
target_audio_embeds = [] |
|
for _ in range(len(audio_feature_lens_raw[i])): |
|
target_audio_embeds.append(audio_embeds[idx, :num_audio_tokens[idx], :]) |
|
idx += 1 |
|
final_audio_embeds.append(target_audio_embeds) |
|
return final_audio_embeds |
|
elif self.training and dummy: |
|
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype) |
|
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer] |
|
|
|
audio_embeds = self.audio_projection_layer(audio_states) |
|
|
|
audio_embeds = audio_embeds.transpose(1, 2) |
|
audio_embeds = self.audio_avg_pooler(audio_embeds) |
|
audio_embeds = audio_embeds.transpose(1, 2) |
|
return [audio_embeds] |
|
|
|
else: |
|
return [] |
|
|
|
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1): |
|
audio_embeddings = self.get_audio_embedding(data, chunk_length) |
|
|
|
bs = len(input_embeddings) |
|
if len(data.get('audio_features', [])) > 0: |
|
assert len(audio_embeddings) == len(input_embeddings) |
|
if len(audio_embeddings) > 0: |
|
audio_bounds = data['audio_bounds'] |
|
|
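                # Copy the audio embeddings into the placeholder positions delimited by `audio_bounds`.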
|
if self.config.stream_input: |
|
for i in range(bs): |
|
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(device=input_embeddings.device, |
|
dtype=input_embeddings.dtype) |
|
audio_start_pos = 0 |
|
for bound in audio_bounds[i]: |
|
audio_len = bound[1] - bound[0] |
|
                            input_embeddings[i, bound[0]:bound[1]] = audio_embs[
|
audio_start_pos:audio_start_pos + audio_len, :] |
|
audio_start_pos += audio_len |
|
else: |
|
for i in range(bs): |
|
audio_embs = audio_embeddings[i] |
|
bounds = audio_bounds[i] |
|
for embs, bound in zip(audio_embs, bounds): |
|
audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( |
|
input_embeddings.device) |
|
|
|
if embs.shape[0] != len(audio_indices): |
|
print(f"Sample {i}:") |
|
print(f" Bounds: {bound}, Indices Length: {len(audio_indices)}") |
|
print(f" Embeddings Shape: {embs.shape}") |
|
print(f" Input Embedding Shape at Indices: {input_embeddings[i, audio_indices].shape}") |
|
raise ValueError( |
|
f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " |
|
f"to input indices of length {len(audio_indices)}" |
|
) |
|
input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype) |
|
elif self.training: |
|
for i in range(bs): |
|
|
|
input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0 |
|
|
|
return input_embeddings |
|
|
|
def forward(self, data, **kwargs): |
|
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) |
|
vllm_embedding = self.get_omni_embedding(data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length) |
|
|
|
position_ids = data["position_ids"] |
|
if position_ids.dtype != torch.int64: |
|
position_ids = position_ids.long() |
|
|
|
for key in ['input_ids', 'inputs_embeds', 'position_ids']: |
|
if key in kwargs: |
|
del kwargs[key] |
|
|
|
return self.llm( |
|
input_ids=None, |
|
position_ids=position_ids, |
|
inputs_embeds=vllm_embedding, |
|
**kwargs |
|
) |
|
|
|
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs): |
|
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
|
outputs = self.llm.generate( |
|
inputs_embeds=inputs_embeds, |
|
pad_token_id=0, |
|
eos_token_id=terminators, |
|
attention_mask=attention_mask, |
|
output_hidden_states=True, |
|
return_dict_in_generate=True, |
|
**kwargs |
|
) |
|
return outputs |
|
|
|
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs): |
|
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
|
streamer = TextIteratorStreamer(tokenizer=tokenizer) |
|
generation_kwargs = { |
|
'inputs_embeds': inputs_embeds, |
|
'pad_token_id': 0, |
|
'eos_token_id': terminators, |
|
'streamer': streamer |
|
} |
|
generation_kwargs.update(kwargs) |
|
|
|
thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) |
|
thread.start() |
|
|
|
return streamer |
|
|
|
def _decode_text(self, result_ids, tokenizer): |
|
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
|
result_text = [] |
|
for result in result_ids: |
|
result = result[result != 0] |
|
if result[0] == tokenizer.bos_id: |
|
result = result[1:] |
|
if result[-1] in terminators: |
|
result = result[:-1] |
|
result_text.append(tokenizer.decode(result).strip()) |
|
return result_text |
|
|
|
def generate( |
|
self, |
|
input_ids=None, |
|
pixel_values=None, |
|
tgt_sizes=None, |
|
audio_features=[], |
|
audio_feature_lens=None, |
|
image_bound=None, |
|
audio_bounds=None, |
|
spk_bounds=None, |
|
attention_mask=None, |
|
tokenizer=None, |
|
vision_hidden_states=None, |
|
stream=False, |
|
**kwargs |
|
): |
|
assert input_ids is not None |
|
assert len(input_ids) == len(pixel_values) |
|
|
|
model_inputs = { |
|
"input_ids": input_ids, |
|
"audio_features": audio_features, |
|
"audio_feature_lens": audio_feature_lens, |
|
"image_bound": image_bound, |
|
"audio_bounds": audio_bounds, |
|
"spk_bounds": spk_bounds, |
|
} |
|
|
|
if vision_hidden_states is None: |
|
model_inputs["pixel_values"] = pixel_values |
|
model_inputs['tgt_sizes'] = tgt_sizes |
|
else: |
|
model_inputs["vision_hidden_states"] = vision_hidden_states |
|
|
|
model_output = {} |
|
with torch.inference_mode(): |
|
model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) |
|
model_inputs["inputs_embeds"] = self.get_omni_embedding( |
|
model_inputs, input_embeddings=model_inputs["inputs_embeds"], chunk_length=self.config.audio_chunk_length |
|
) |
|
|
|
if stream: |
|
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs) |
|
|
|
outputs = {} |
|
else: |
|
outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs) |
|
result = self._decode_text(outputs.sequences, tokenizer) |
|
|
|
return result, outputs |
|
|
|
def chat( |
|
self, |
|
image, |
|
msgs, |
|
tokenizer, |
|
processor=None, |
|
vision_hidden_states=None, |
|
max_new_tokens=2048, |
|
min_new_tokens=0, |
|
sampling=True, |
|
max_inp_length=8192, |
|
stream=False, |
|
stream_input=True, |
|
omni_input=False, |
|
max_slice_nums=None, |
|
use_image_id=None, |
|
use_tts=False, |
|
output_audio_path=None, |
|
return_spk_embed=False, |
|
**kwargs |
|
): |
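        # Minimal usage sketch (illustrative only; the checkpoint path, image file and
        # AutoTokenizer import are placeholders, not provided by this module):
        #   model = MiniCPMO.from_pretrained("<checkpoint>", trust_remote_code=True).eval().cuda()
        #   tokenizer = AutoTokenizer.from_pretrained("<checkpoint>", trust_remote_code=True)
        #   msgs = [{"role": "user", "content": [Image.open("example.jpg"), "Describe this image."]}]
        #   answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)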
|
if isinstance(msgs[0], list): |
|
batched = True |
|
else: |
|
batched = False |
|
msgs_list = msgs |
|
images_list = image |
|
|
|
if batched is False: |
|
images_list, msgs_list = [images_list], [msgs_list] |
|
else: |
|
            assert images_list is None, "Please integrate images into msgs when using batch inference."
|
images_list = [None] * len(msgs_list) |
|
assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." |
|
|
|
if processor is None: |
|
if self.processor is None: |
|
self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) |
|
processor = self.processor |
|
|
|
assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
|
|
prompts_lists = [] |
|
input_images_list = [] |
|
input_audios_list = [] |
|
audio_parts_list = [] |
|
|
|
for image, msgs in zip(images_list, msgs_list): |
|
if isinstance(msgs, str): |
|
msgs = json.loads(msgs) |
|
copy_msgs = deepcopy(msgs) |
|
|
|
assert len(msgs) > 0, "msgs is empty" |
|
            assert sampling or not stream, "if using stream mode, make sure sampling=True"
|
|
|
if image is not None and isinstance(copy_msgs[0]["content"], str): |
|
copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] |
|
|
|
images = [] |
|
audios = [] |
|
audio_parts = [] |
|
for i, msg in enumerate(copy_msgs): |
|
role = msg["role"] |
|
content = msg["content"] |
|
assert role in ["system", "user", "assistant"] |
|
if i == 0: |
|
assert role in ["user", "system"], "The role of first msg should be user" |
|
if isinstance(content, str): |
|
content = [content] |
|
cur_msgs = [] |
|
for c in content: |
|
if isinstance(c, Image.Image): |
|
images.append(c) |
|
cur_msgs.append("<image>./</image>") |
|
elif isinstance(c, np.ndarray): |
|
audios.append(c) |
|
audio_parts.append(i) |
|
cur_msgs.append("<audio>./</audio>") |
|
elif isinstance(c, str): |
|
cur_msgs.append(c) |
|
if omni_input: |
|
msg["content"] = "".join(cur_msgs) |
|
else: |
|
msg["content"] = "\n".join(cur_msgs) |
|
|
|
prompts_lists.append( |
|
processor.tokenizer.apply_chat_template( |
|
copy_msgs, |
|
tokenize=False, |
|
add_generation_prompt=True, |
|
chat_template=self.default_tts_chat_template if use_tts else None |
|
) |
|
) |
|
input_images_list.append(images) |
|
input_audios_list.append(audios) |
|
audio_parts_list.append(audio_parts) |
|
|
|
|
|
inputs = processor( |
|
prompts_lists, |
|
input_images_list, |
|
input_audios_list, |
|
audio_parts_list, |
|
max_slice_nums=max_slice_nums, |
|
use_image_id=use_image_id, |
|
stream_input=stream_input, |
|
return_tensors="pt", |
|
max_length=max_inp_length |
|
).to(self.device) |
|
|
|
if sampling: |
|
generation_config = { |
|
"top_p": 0.8, |
|
"top_k": 100, |
|
"temperature": 0.7, |
|
"do_sample": True, |
|
"repetition_penalty": 1.05 |
|
} |
|
else: |
|
generation_config = { |
|
"num_beams": 3, |
|
"repetition_penalty": 1.2, |
|
} |
|
|
|
if min_new_tokens > 0: |
|
generation_config['min_new_tokens'] = min_new_tokens |
|
|
|
generation_config.update( |
|
(k, kwargs[k]) for k in generation_config.keys() & kwargs.keys() |
|
) |
|
|
|
inputs.pop("image_sizes") |
|
with torch.inference_mode(): |
|
res, outputs = self.generate( |
|
**inputs, |
|
tokenizer=tokenizer, |
|
max_new_tokens=max_new_tokens, |
|
vision_hidden_states=vision_hidden_states, |
|
stream=stream, |
|
**generation_config |
|
) |
|
|
|
if stream: |
|
def stream_gen(): |
|
for text in res: |
|
for term in self.terminators: |
|
text = text.replace(term, '') |
|
yield text |
|
return stream_gen() |
|
|
|
else: |
|
if batched: |
|
answer = res |
|
else: |
|
answer = res[0] |
|
|
|
if use_tts and output_audio_path: |
|
mel_spec = self._generate_mel_spec(inputs, outputs, answer) |
|
self.decode_mel_to_audio(mel_spec, output_audio_path) |
|
|
|
if return_spk_embed: |
|
spk_embeds = self._get_last_spk_embeds(inputs, outputs) |
|
return answer, spk_embeds |
|
else: |
|
return answer |
|
|
|
def prepare_tts_text(self, text): |
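        """Wrap `text` with ChatTTS control tokens and pad or truncate it to a fixed length.

        The result is "[Stts]" + "[spk_emb]" placeholders + text + "[Etts]"/"[PAD]" padding (or a
        truncation to `streaming_text_reserved_len` tokens) + "[Ptts]". Returns the wrapped string
        and the number of text tokens actually kept.
        """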
|
tts_tokens = self.tts_processor.text_tokenizer.encode(text, add_special_tokens=False) |
|
tts_tokens_len = len(tts_tokens) |
|
if tts_tokens_len < self.tts.streaming_text_reserved_len: |
|
num_pad_tokens = self.tts.streaming_text_reserved_len - tts_tokens_len |
|
|
|
pad_str = "[Etts]" + "[PAD]" * (num_pad_tokens - 1) |
|
else: |
|
tts_tokens = tts_tokens[0: self.tts.streaming_text_reserved_len] |
|
tts_tokens_len = len(tts_tokens) |
|
text = self.tts_processor.text_tokenizer.decode(tts_tokens, add_special_tokens=False) |
|
pad_str = "" |
|
spk_emb_placeholder_tts = "[spk_emb]" * self.tts.num_spk_embs |
|
|
|
new_text_tts = f"[Stts]{spk_emb_placeholder_tts}{text}{pad_str}[Ptts]" |
|
return new_text_tts, tts_tokens_len |
|
|
|
def _build_streaming_mask(self, tts_tokens_len): |
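        """Build a binary mask over the full TTS condition sequence
        ([Stts] + speaker embedding(s) + reserved text region + [Ptts]): positions covering the
        prompt tokens and the valid text are set to 1, padded text positions to 0, and the final
        position is always kept visible."""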
|
tts_sequence_full_length = 1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1 |
|
streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8) |
|
streaming_attention_mask[0: 1 + 1 + tts_tokens_len + 1] = 1 |
|
streaming_attention_mask[-1] = 1 |
|
return streaming_attention_mask |
|
|
|
def _get_last_spk_embeds(self, inputs, outputs): |
|
last_hidden_states = [hs[-1] for hs in outputs.hidden_states] |
|
|
|
|
|
last_hidden_states = torch.vstack([i[0] for i in last_hidden_states]) |
|
|
|
|
|
spk_bound = inputs['spk_bounds'][0][-1] |
|
|
|
spk_embeds = last_hidden_states[spk_bound[0]: spk_bound[1]] |
|
return spk_embeds |
|
|
|
def _generate_mel_spec(self, inputs, outputs, text): |
|
spk_embeds = self._get_last_spk_embeds(inputs, outputs) |
|
|
|
gen_text = text.replace('<|tts_eos|>', '') |
|
tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
|
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) |
|
        tts_input_ids = torch.tensor(tts_inputs, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
|
|
|
logits_warpers, logits_processors = gen_logits( |
|
num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty |
|
) |
|
|
|
condition_length = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 |
|
|
|
dtype = self.tts.emb_text.weight.dtype |
|
emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) |
|
past_key_values = [ |
|
( |
|
torch.zeros(1, self.tts.config.num_attention_heads, condition_length - 1, |
|
self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, |
|
device=self.tts.device), |
|
torch.zeros(1, self.tts.config.num_attention_heads, condition_length - 1, |
|
self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, |
|
device=self.tts.device) |
|
) |
|
for _ in range(self.tts.config.num_hidden_layers) |
|
] |
|
|
|
audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) |
|
|
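        # Alternate between prefilling the next chunk of text tokens into the TTS KV cache and
        # generating up to 25 new audio codes, emulating streaming synthesis.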
|
eos_lab = False |
|
for chunk_idx in range(math.ceil(emb.shape[1] / self.streaming_text_chunk_size)): |
|
if chunk_idx == 0: |
|
begin = chunk_idx * self.streaming_text_chunk_size + 0 |
|
end = (chunk_idx + 1) * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
|
else: |
|
begin = chunk_idx * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
|
end = min((chunk_idx + 1) * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs, |
|
condition_length - 1) |
|
if end - begin < 1: |
|
print(f"BKing has break by the end of {end} and begin of {begin}") |
|
else: |
|
text_input_ids = tts_input_ids[:, begin: end] |
|
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
|
|
|
if begin == 0: |
|
past_key_values = self.tts.prefill_text( |
|
input_ids=text_input_ids, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
lm_spk_emb_last_hidden_states=spk_embeds |
|
) |
|
else: |
|
past_key_values = self.tts.prefill_text( |
|
input_ids=text_input_ids, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values |
|
) |
|
|
|
outputs = self.tts.generate( |
|
input_ids=audio_input_ids, |
|
past_key_values=past_key_values, |
|
streaming_tts_text_mask=streaming_tts_text_mask, |
|
max_new_token=25, |
|
force_no_stop=self.force_no_stop, |
|
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
|
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
|
logits_warpers=logits_warpers, |
|
logits_processors=logits_processors, |
|
) |
|
|
|
audio_input_ids = outputs.audio_input_ids |
|
past_key_values = outputs.past_key_values |
|
|
|
if outputs.finished: |
|
print("Generation finished.") |
|
eos_lab = True |
|
break |
|
|
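        # If the text has been fully prefilled but no EOS was produced, keep generating audio
        # codes until EOS or a hard cap of 2048 new codes.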
|
if not eos_lab: |
|
print("Generation not finished.") |
|
while True: |
|
outputs = self.tts.generate( |
|
input_ids=audio_input_ids, |
|
past_key_values=past_key_values, |
|
streaming_tts_text_mask=streaming_tts_text_mask, |
|
max_new_token=25, |
|
force_no_stop=self.force_no_stop, |
|
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
|
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
|
logits_warpers=logits_warpers, |
|
logits_processors=logits_processors, |
|
) |
|
|
|
audio_input_ids = outputs.audio_input_ids |
|
past_key_values = outputs.past_key_values |
|
|
|
if outputs.finished: |
|
print("Generation finished.") |
|
break |
|
if outputs.new_ids.shape[1] > 2048: |
|
print("Generation not finished but break.") |
|
break |
|
|
|
mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids) |
|
print("Mel spectrogram generated.") |
|
return mel_spec |
|
|
|
def decode_mel_to_audio(self, mel_spec, output_path="test.wav"): |
|
if self.vocos is None: |
|
self.vocos = self.initialize_vocos() |
|
|
|
with torch.inference_mode(): |
|
wav_numpy = self.vocos.decode(mel_spec.float()).cpu().numpy().squeeze() |
|
sf.write(output_path, wav_numpy, samplerate=24000) |
|
print(f"Audio saved to {output_path}.") |
|
|
|
|
|
class MiniCPMWhisperEncoderLayer(nn.Module): |
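    """Whisper encoder layer whose self-attention can read from and write to an
    `EncoderDecoderCache`, enabling incremental (streaming) audio encoding."""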
|
def __init__(self, config: WhisperConfig, layer_idx: int = None): |
|
super().__init__() |
|
self.embed_dim = config.d_model |
|
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( |
|
embed_dim=self.embed_dim, |
|
num_heads=config.encoder_attention_heads, |
|
dropout=config.attention_dropout, |
|
config=config, |
|
layer_idx=layer_idx |
|
) |
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
self.dropout = config.dropout |
|
self.activation_fn = ACT2FN[config.activation_function] |
|
self.activation_dropout = config.activation_dropout |
|
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) |
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) |
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
layer_head_mask: torch.Tensor, |
|
output_attentions: bool = False, |
|
past_key_values: Optional[EncoderDecoderCache] = None, |
|
use_cache: Optional[bool] = False, |
|
) -> torch.Tensor: |
|
residual = hidden_states |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
hidden_states, attn_weights, past_key_values = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
layer_head_mask=layer_head_mask, |
|
output_attentions=output_attentions, |
|
past_key_value=past_key_values |
|
) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) |
|
hidden_states = self.fc2(hidden_states) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
if hidden_states.dtype == torch.float16 and ( |
|
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
|
): |
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (past_key_values,) |
|
|
|
return outputs |
|
|
|
class MiniCPMWhisperEncoder(WhisperEncoder): |
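    """Whisper encoder variant that threads a KV cache (`past_key_values` / `use_cache`) through its
    layers so that audio can be encoded incrementally, e.g. streaming input longer than 30s."""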
|
|
|
def __init__(self, config: WhisperConfig): |
|
super().__init__(config) |
|
self.layers = nn.ModuleList([ |
|
MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers) |
|
]) |
|
|
|
def forward( |
|
self, |
|
input_features, |
|
attention_mask=None, |
|
head_mask=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
past_key_values: Optional[EncoderDecoderCache] = None, |
|
use_cache: Optional[bool] = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) |
|
|
|
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) |
|
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) |
|
|
|
inputs_embeds = inputs_embeds.permute(0, 2, 1) |
|
|
|
embed_pos = self.embed_positions.weight |
|
past_key_values_length = 0 |
|
if use_cache: |
|
if past_key_values is None: |
|
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
|
elif isinstance(past_key_values, list): |
|
past_key_values = EncoderDecoderCache( |
|
DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) |
|
elif isinstance(past_key_values, DynamicCache): |
|
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
|
else: |
|
pass |
|
past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) |
|
            if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
                # `padding_logged` is a module-level flag; declare it global so the assignment
                # below does not shadow it with an unbound local variable.
                global padding_logged
                if not padding_logged:
                    padding_logged = True
                    logger.warning("Audio seems longer than 30s; repeating the last positional embedding for the extra frames.")
|
embed_pos_front = embed_pos[past_key_values_length:, :] |
|
embed_pos = torch.cat(( |
|
embed_pos_front, |
|
torch.repeat_interleave( |
|
embed_pos[-1, :].unsqueeze(0), |
|
inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, |
|
dim=0 |
|
) |
|
)) |
|
else: |
|
embed_pos = embed_pos[past_key_values_length:inputs_embeds.shape[1] + past_key_values_length, :] |
|
else: |
|
embed_pos = embed_pos[:inputs_embeds.shape[1], :] |
|
|
|
hidden_states = inputs_embeds + embed_pos |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
|
encoder_states = () if output_hidden_states else None |
|
all_attentions = () if output_attentions else None |
|
|
|
|
|
if head_mask is not None: |
|
assert head_mask.size()[0] == ( |
|
len(self.layers) |
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." |
|
|
|
for idx, encoder_layer in enumerate(self.layers): |
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
to_drop = False |
|
if self.training: |
|
dropout_probability = torch.rand([]) |
|
if dropout_probability < self.layerdrop: |
|
to_drop = True |
|
|
|
|
|
if to_drop: |
|
layer_outputs = (None, None) |
|
else: |
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
encoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
(head_mask[idx] if head_mask is not None else None), |
|
output_attentions, |
|
past_key_values, |
|
use_cache |
|
) |
|
else: |
|
layer_outputs = encoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None), |
|
output_attentions=output_attentions, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
next_encoder_cache = layer_outputs[2 if output_attentions else 1] |
|
else: |
|
next_encoder_cache = None |
|
|
|
if output_attentions: |
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
hidden_states = self.layer_norm(hidden_states) |
|
if output_hidden_states: |
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
hidden_states=encoder_states, |
|
attentions=all_attentions, |
|
past_key_values=next_encoder_cache |
|
) |
|
|
|
|
|
class ConvNeXtBlock(nn.Module): |
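    """1D ConvNeXt block (depthwise conv -> LayerNorm -> pointwise MLP with GELU and an optional
    learnable layer scale) used by the DVAE encoder/decoder stacks."""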
|
def __init__( |
|
self, |
|
dim: int, |
|
intermediate_dim: int, |
|
kernel: int, |
|
dilation: int, |
|
layer_scale_init_value: float = 1e-6, |
|
): |
|
|
|
super().__init__() |
|
self.dwconv = nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size=kernel, |
|
padding=dilation * (kernel // 2), |
|
dilation=dilation, |
|
groups=dim, |
|
) |
|
|
|
self.norm = nn.LayerNorm(dim, eps=1e-6) |
|
self.pwconv1 = nn.Linear( |
|
dim, intermediate_dim |
|
) |
|
self.act = nn.GELU() |
|
self.pwconv2 = nn.Linear(intermediate_dim, dim) |
|
self.coef = ( |
|
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) |
|
if layer_scale_init_value > 0 |
|
else None |
|
) |
|
|
|
def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: |
|
residual = x |
|
|
|
y = self.dwconv(x) |
|
y.transpose_(1, 2) |
|
x = self.norm(y) |
|
del y |
|
y = self.pwconv1(x) |
|
del x |
|
x = self.act(y) |
|
del y |
|
y = self.pwconv2(x) |
|
del x |
|
if self.coef is not None: |
|
y *= self.coef |
|
y.transpose_(1, 2) |
|
|
|
x = y + residual |
|
del y |
|
|
|
return x |
|
|
|
|
|
class GFSQ(nn.Module): |
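    """Grouped residual finite scalar quantizer: wraps `GroupedResidualFSQ` so that `forward` maps
    latent features to flattened code indices and `_embed` maps indices back to features."""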
|
def __init__( |
|
self, |
|
dim: int, |
|
levels: List[int], |
|
G: int, |
|
R: int, |
|
eps=1e-5, |
|
transpose=True, |
|
): |
|
super(GFSQ, self).__init__() |
|
self.quantizer = GroupedResidualFSQ( |
|
dim=dim, |
|
levels=list(levels), |
|
num_quantizers=R, |
|
groups=G, |
|
) |
|
self.n_ind = math.prod(levels) |
|
self.eps = eps |
|
self.transpose = transpose |
|
self.G = G |
|
self.R = R |
|
|
|
def _embed(self, x: torch.Tensor): |
|
if self.transpose: |
|
x = x.transpose(1, 2) |
|
x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) |
|
feat = self.quantizer.get_output_from_indices(x) |
|
return feat.transpose_(1, 2) if self.transpose else feat |
|
|
|
def __call__(self, x: torch.Tensor) -> torch.Tensor: |
|
return super().__call__(x) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.transpose: |
|
x.transpose_(1, 2) |
|
_, ind = self.quantizer(x) |
|
ind = ind.permute(1, 2, 0, 3).contiguous() |
|
ind = ind.view(ind.size(0), ind.size(1), -1) |
|
return ind.transpose_(1, 2) if self.transpose else ind |
|
|
|
|
|
class DVAEDecoder(nn.Module): |
|
def __init__( |
|
self, |
|
idim: int, |
|
odim: int, |
|
n_layer=12, |
|
bn_dim=64, |
|
hidden=256, |
|
kernel=7, |
|
dilation=2, |
|
up=False, |
|
): |
|
super().__init__() |
|
self.up = up |
|
self.conv_in = nn.Sequential( |
|
nn.Conv1d(idim, bn_dim, 3, 1, 1), |
|
nn.GELU(), |
|
nn.Conv1d(bn_dim, hidden, 3, 1, 1), |
|
) |
|
self.decoder_block = nn.ModuleList( |
|
[ |
|
ConvNeXtBlock( |
|
hidden, |
|
hidden * 4, |
|
kernel, |
|
dilation, |
|
) |
|
for _ in range(n_layer) |
|
] |
|
) |
|
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) |
|
|
|
def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: |
|
|
|
y = self.conv_in(x) |
|
del x |
|
for f in self.decoder_block: |
|
y = f(y, conditioning) |
|
|
|
x = self.conv_out(y) |
|
del y |
|
return x |
|
|
|
|
|
class DVAE(nn.Module): |
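    """ChatTTS-style discrete VAE: encodes 100-bin mel spectrograms into grouped-FSQ audio codes
    and decodes quantized codes back into mel spectrograms."""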
|
def __init__( |
|
self, |
|
): |
|
super().__init__() |
|
|
|
coef = torch.rand(100) |
|
self.coef = nn.Parameter(coef.unsqueeze(0).unsqueeze_(2)) |
|
|
|
self.downsample_conv = nn.Sequential( |
|
nn.Conv1d(100, 512, 3, 1, 1), |
|
nn.GELU(), |
|
nn.Conv1d(512, 512, 4, 2, 1), |
|
nn.GELU(), |
|
) |
|
|
|
self.encoder = DVAEDecoder( |
|
idim=512, |
|
odim=1024, |
|
hidden=256, |
|
n_layer=12, |
|
bn_dim=128, |
|
) |
|
|
|
self.decoder = DVAEDecoder( |
|
idim=512, |
|
odim=512, |
|
hidden=256, |
|
n_layer=12, |
|
bn_dim=128, |
|
) |
|
|
|
self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False) |
|
|
|
self.vq_layer = GFSQ( |
|
dim=1024, |
|
levels=(5, 5, 5, 5), |
|
G=2, |
|
R=2, |
|
) |
|
|
|
@torch.inference_mode() |
|
def forward( |
|
self, |
|
inp: torch.Tensor, |
|
mode: Literal["encode", "decode"] = "decode" |
|
) -> torch.Tensor: |
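        # mode="encode": mel spectrogram -> discrete FSQ code indices;
        # mode="decode" (default): quantized features -> mel spectrogram.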
|
if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: |
|
mel = inp.clone() |
|
x: torch.Tensor = self.downsample_conv( |
|
torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel), |
|
).unsqueeze_(0) |
|
del mel |
|
x = self.encoder(x) |
|
ind = self.vq_layer(x) |
|
del x |
|
return ind |
|
|
|
if self.vq_layer is not None: |
|
vq_feats = self.vq_layer._embed(inp) |
|
else: |
|
vq_feats = inp |
|
|
|
vq_feats = ( |
|
vq_feats.view( |
|
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)), |
|
) |
|
.permute(0, 2, 3, 1) |
|
.flatten(2) |
|
) |
|
|
|
dec_out = self.out_conv( |
|
self.decoder( |
|
x=vq_feats, |
|
), |
|
) |
|
|
|
del vq_feats |
|
|
|
return torch.mul(dec_out, self.coef, out=dec_out) |
|
|
|
|
|
|
|
def apply_spk_emb( |
|
input_ids: torch.Tensor = None, |
|
spk_emb: torch.Tensor = None, |
|
input_embeds: torch.Tensor = None, |
|
spk_emb_token_id: int = 0, |
|
num_spk_embs: int = 1, |
|
): |
|
""" |
|
Replace consecutive speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned. |
|
|
|
Args: |
|
input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max] |
|
spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim] |
|
input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim] |
|
spk_emb_token_id (int): ID of the speaker embedding token |
|
num_spk_embs (int): Number of speaker embeddings |
|
|
|
Returns: |
|
None |
|
""" |
|
|
|
batch_size = input_ids.shape[0] |
|
|
|
for idx in range(batch_size): |
|
input_ids_ = input_ids[idx] |
|
spk_emb_ = spk_emb[idx] |
|
mask_ = input_ids_ == spk_emb_token_id |
|
nonzero_position_idx = mask_.nonzero(as_tuple=False) |
|
assert nonzero_position_idx.shape[0] == num_spk_embs |
|
begin_idx = nonzero_position_idx.min() |
|
end_idx = nonzero_position_idx.max() |
|
input_embeds[idx, begin_idx: end_idx + 1, :] = spk_emb_ |
|
|
|
return |
|
|
|
|
|
def make_streaming_chunk_mask( |
|
input_embeds: torch.Tensor, |
|
tts_text_scopes: List[List[int]], |
|
tts_audio_scopes: List[List[int]], |
|
tts_text_masks: List[torch.Tensor], |
|
min_chunk_num_token: int = 5, |
|
max_chunk_num_token: int = 7, |
|
streaming_audio_chunk_size: int = 50, |
|
): |
|
""" |
|
Create a look-ahead chunked attention mask that allows the TTS transformer to see only the first M tokens when generating each N to N+1 seconds of audio, enabling streaming TTS. |
|
|
|
Args: |
|
input_embeds (torch.Tensor): Input embeddings combining text and audio, shape [batch_size, seq_len, hidden_dim] |
|
tts_text_scopes (List[List[int]]): Range of text tokens for each sample |
|
tts_audio_scopes (List[List[int]]): Range of audio tokens for each sample |
|
tts_text_masks (List[torch.Tensor]): Text masks for each sample |
|
min_chunk_num_token (int): Minimum number of new text tokens the model can see per audio chunk |
|
max_chunk_num_token (int): Maximum number of new text tokens the model can see per audio chunk |
|
streaming_audio_chunk_size (int): Size of audio chunk, 50 corresponds to approximately 1 second of audio |
|
|
|
Returns: |
|
torch.Tensor: 4D causal mask with shape [batch_size, 1, seq_len, seq_len] |
|
|
|
Example: |
|
Input sequence: |
|
[t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...] |
|
Output 4D causal mask: |
|
------- text positions ------- |
|
[0] <- here is [Stts] |
|
[0, 0] <- here is [spk_emb] * N |
|
[0, 0, 0] |
|
[0, 0, 0, 0] |
|
[0, 0, 0, 0, 0] |
|
------- audio positions -------- |
|
[0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token |
|
v- here is [Ptts] |
|
[0, 0, -inf, -inf, -inf, 0, 0] |
|
[0, 0, -inf, -inf, -inf, 0, 0, 0] |
|
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0] |
|
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0] |
|
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk |
|
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0] |
|
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0] |
|
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
|
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
|
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
|
""" |
|
|
|
|
|
batch_size = input_embeds.shape[0] |
|
input_embeds_attention_mask = torch.ones(input_embeds.shape[0], input_embeds.shape[1], dtype=torch.int8, |
|
device=input_embeds.device) |
|
|
|
for idx in range(batch_size): |
|
input_embeds_attention_mask[idx, tts_text_scopes[idx][0]: tts_text_scopes[idx][1]] = tts_text_masks[idx] |
|
|
|
|
|
dtype = input_embeds.dtype |
|
device = input_embeds.device |
|
min_dtype = torch.finfo(dtype).min |
|
sequence_length = input_embeds.shape[1] |
|
causal_mask = torch.full((sequence_length, sequence_length), fill_value=min_dtype, dtype=dtype, device=device) |
|
if sequence_length != 1: |
|
causal_mask = torch.triu(causal_mask, diagonal=1) |
|
else: |
|
raise ValueError("sequence_length of tts could not be 1.") |
|
causal_mask = causal_mask.unsqueeze(0).repeat(input_embeds.shape[0], 1, 1) |
|
|
|
|
|
for idx in range(input_embeds.shape[0]): |
|
tts_audio_scope = tts_audio_scopes[idx] |
|
tts_text_scope = tts_text_scopes[idx] |
|
|
|
audio_token_start = tts_audio_scope[0] |
|
audio_duration = tts_audio_scope[1] - tts_audio_scope[0] |
|
|
|
|
|
text_pivot = 0 |
|
num_valid_text_tokens = torch.sum(tts_text_masks[idx]).item() - 1 |
|
|
|
num_buckets = max(1, math.floor(audio_duration / streaming_audio_chunk_size)) |
|
|
|
|
|
num_text_tokens_per_audio_chunk = math.ceil( |
|
num_valid_text_tokens / num_buckets) |
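        # Clamp the per-chunk text budget to [4, 10] tokens (min_chunk_num_token /
        # max_chunk_num_token are not applied here).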
|
if num_text_tokens_per_audio_chunk > 10: |
|
num_text_tokens_per_audio_chunk = 10 |
|
elif num_text_tokens_per_audio_chunk < 4: |
|
num_text_tokens_per_audio_chunk = 4 |
|
else: |
|
pass |
|
|
|
|
|
|
|
|
|
for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)): |
|
audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size |
|
audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size |
|
|
|
new_text_this_chunk = num_text_tokens_per_audio_chunk |
|
|
|
text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens) |
|
|
|
|
|
causal_mask[ |
|
idx, |
|
audio_chunk_start - 1: audio_chunk_end - 1, |
|
tts_text_scope[0] + text_pivot: tts_text_scope[1] - 1 |
|
] = min_dtype |
|
|
|
|
|
causal_mask[idx, :, input_embeds_attention_mask[idx] == 0] = min_dtype |
|
|
|
|
|
causal_mask = causal_mask.unsqueeze(1) |
|
|
|
return causal_mask |
|
|
|
|
|
def make_streaming_chunk_mask_generation( |
|
inputs_embeds: torch.Tensor, |
|
past_seen_tokens: int, |
|
streaming_tts_text_mask: torch.Tensor, |
|
streaming_reserved_length: int = 300, |
|
streaming_audio_chunk_size: int = 50, |
|
streaming_text_chunk_size: int = 10, |
|
num_spk_emb: int = 1, |
|
use_spk_emb: bool = True, |
|
) -> torch.Tensor: |
|
""" |
|
Determine which `text` tokens the model can attend to when generating each chunk of `audio` tokens. |
|
|
|
This function creates a mask that allows the model to attend to a specific chunk of text |
|
tokens when generating each chunk of audio tokens, enabling streaming TTS generation. |
|
|
|
Args: |
|
inputs_embeds (torch.Tensor): Input embeddings tensor. |
|
past_seen_tokens (int): Number of tokens already seen by the model. |
|
streaming_tts_text_mask (torch.Tensor): Mask for the text tokens. |
|
streaming_reserved_length (int, optional): Number of reserved tokens for streaming. Defaults to 300. |
|
        streaming_audio_chunk_size (int, optional): Number of audio tokens generated per chunk. Defaults to 50.
        streaming_text_chunk_size (int, optional): Number of text tokens revealed per audio chunk. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker-embedding tokens in the condition. Defaults to 1.
        use_spk_emb (bool, optional): Whether speaker-embedding tokens are present. Defaults to True.
|
|
|
Returns: |
|
torch.Tensor: Causal mask for streaming TTS generation, shape is [batch_size=1, 1, seq_len=1, past_seen_tokens+1] |
|
|
|
Raises: |
|
AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference). |
|
""" |
|
assert inputs_embeds.shape[0] == 1 |
|
|
|
dtype = inputs_embeds.dtype |
|
device = inputs_embeds.device |
|
min_dtype = torch.finfo(dtype).min |
|
|
|
|
|
causal_mask = torch.full((1, past_seen_tokens + 1), fill_value=0, dtype=dtype, device=device) |
|
|
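    # Each chunk of `streaming_audio_chunk_size` generated audio tokens reveals another
    # `streaming_text_chunk_size` text tokens; text beyond the currently revealed prefix (and any
    # padded text positions) is masked out for this decoding step.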
|
|
|
invisible_text_tokens_start = min( |
|
math.ceil( |
|
(past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size |
|
) * streaming_text_chunk_size, |
|
streaming_reserved_length |
|
) + 1 + num_spk_emb * use_spk_emb |
|
|
|
invisible_text_tokens_end = streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1 |
|
|
|
|
|
causal_mask[0, invisible_text_tokens_start: invisible_text_tokens_end] = min_dtype |
|
|
|
|
|
causal_mask[0, 0: 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1].masked_fill_( |
|
streaming_tts_text_mask == 0, min_dtype) |
|
|
|
|
|
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) |
|
|
|
return causal_mask |
|
|
|
|
|
class CustomRepetitionPenaltyLogitsProcessorRepeat: |
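    """Repetition penalty applied over a sliding window of the last `past_window` generated tokens,
    as used by ChatTTS for audio-code generation."""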
|
def __init__(self, penalty: float, max_input_ids: int, past_window: int): |
|
if not isinstance(penalty, float) or not (penalty > 0): |
|
raise ValueError( |
|
f"`penalty` has to be a strictly positive float, but is {penalty}" |
|
) |
|
|
|
self.penalty = penalty |
|
self.max_input_ids = max_input_ids |
|
self.past_window = past_window |
|
|
|
def __call__( |
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor |
|
) -> torch.FloatTensor: |
|
if input_ids.size(1) > self.past_window: |
|
input_ids = input_ids.narrow(1, -self.past_window, self.past_window) |
|
freq = F.one_hot(input_ids, scores.size(1)).sum(1) |
|
if freq.size(0) > self.max_input_ids: |
|
freq.narrow( |
|
0, self.max_input_ids, freq.size(0) - self.max_input_ids |
|
).zero_() |
|
alpha = torch.pow(self.penalty, freq) |
|
scores = scores.contiguous() |
|
inp = scores.multiply(alpha) |
|
oth = scores.divide(alpha) |
|
con = scores < 0 |
|
out = torch.where(con, inp, oth) |
|
del inp, oth, scores, con, alpha |
|
return out |
|
|
|
|
|
@dataclass |
|
class ConditionalChatTTSGenerationOutput(ModelOutput): |
|
""" |
|
Output class for ConditionalChatTTS generation. |
|
|
|
Args: |
|
new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq). |
|
audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq). |
|
past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head). |
|
finished (bool): Boolean indicating whether generation is complete. |
|
|
|
""" |
|
|
|
new_ids: torch.LongTensor = None |
|
audio_input_ids: torch.LongTensor = None |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
finished: bool = None |
|
|
|
|
|
class MultiModalProjector(nn.Module): |
|
def __init__(self, in_dim, out_dim): |
|
super().__init__() |
|
self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) |
|
self.relu = nn.ReLU() |
|
self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) |
|
|
|
def forward(self, audio_features): |
|
hidden_states = self.relu(self.linear1(audio_features)) |
|
hidden_states = self.linear2(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class ConditionalChatTTS(PreTrainedModel): |
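    """Streaming conditional TTS module in the style of ChatTTS: a LLaMA backbone conditioned on
    projected LLM hidden states (or speaker embeddings) autoregressively generates `num_vq` parallel
    streams of discrete audio codes, which the DVAE decodes into mel spectrograms."""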
|
config_class = ConditionalChatTTSConfig |
|
_no_split_modules = [] |
|
|
|
def __init__( |
|
self, |
|
config: ConditionalChatTTSConfig |
|
): |
|
super().__init__(config) |
|
|
|
self.use_speaker_embedding = config.use_speaker_embedding |
|
self.use_llm_hidden_state = config.use_llm_hidden_state |
|
self.num_spk_embs = config.num_spk_embs |
|
self.spk_emb_token_id = config.spk_emb_token_id |
|
|
|
self.use_text = config.use_text |
|
self.streaming = config.streaming |
|
self.streaming_text_chunk_min = config.streaming_text_chunk_min |
|
self.streaming_text_chunk_max = config.streaming_text_chunk_max |
|
self.streaming_text_chunk_size = config.streaming_text_chunk_size |
|
self.streaming_audio_chunk_size = config.streaming_audio_chunk_size |
|
self.streaming_text_reserved_len = config.streaming_text_reserved_len |
|
self.audio_bos_token_id = config.audio_bos_token_id |
|
self.num_mel_bins = config.num_mel_bins |
|
self.num_vq = config.num_vq |
|
self.num_audio_tokens = config.num_audio_tokens |
|
|
|
self.top_p = config.top_p |
|
self.top_k = config.top_k |
|
self.repetition_penalty = config.repetition_penalty |
|
|
|
if self.config.use_mlp: |
|
self.projector = MultiModalProjector(config.llm_dim, config.hidden_size) |
|
else: |
|
self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False) |
|
self.emb_code = nn.ModuleList( |
|
[ |
|
nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq) |
|
] |
|
) |
|
self.emb_text = nn.Embedding( |
|
config.num_text_tokens, config.hidden_size |
|
) |
|
self.head_code = nn.ModuleList( |
|
[ |
|
weight_norm( |
|
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), |
|
name="weight", |
|
) for _ in range(config.num_vq) |
|
] |
|
) |
|
dvae = DVAE() |
|
self.dvae = dvae |
|
|
|
model_config = LlamaConfig( |
|
hidden_size=config.hidden_size, |
|
intermediate_size=config.intermediate_size, |
|
num_attention_heads=config.num_attention_heads, |
|
num_hidden_layers=config.num_hidden_layers, |
|
max_position_embeddings=config.max_position_embeddings, |
|
attn_implementation=config.attn_implementation, |
|
) |
|
|
|
model = LlamaModel(model_config) |
|
self.model = model |
|
|
|
return |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
lm_spk_emb_last_hidden_states=None, |
|
lm_last_hidden_states=None, |
|
target_audio_features=None, |
|
streaming_tts_text_masks=None, |
|
**kwargs, |
|
): |
|
""" |
|
Calculate TTS modeling loss. Only used in training. |
|
|
|
Process: |
|
- LLM last hidden states (obtained from LLM, with gradients) |
|
- Text ground truth (without gradients) |
|
- Target audio features (without gradients) |
|
|
|
Updates: |
|
- 2024/10/3: Support empty input (dummy train) for tasks without audio, preventing training stalls due to unused parameters. |
|
- 2024/10/11: Support EOS token |
|
|
|
Args: |
|
            input_ids (List[Tensor[seq_len]]): Ground-truth text input_ids for each sample's speech segment. Each element is a variable-length Tensor.
            lm_spk_emb_last_hidden_states (List[Tensor[gpt_dim]], optional): Speaker embedding last hidden states from the language model.
            lm_last_hidden_states (List[Tensor[seq_len, gpt_dim]], optional): LLM last hidden states for each sample's speech segment. Each element is a variable-length Tensor.
            target_audio_features (List[Tensor[num_channels, num_samples]], optional): Ground-truth mel spectrogram for each sample's speech segment. Each element is a variable-length Tensor.
|
streaming_tts_text_masks (List[Tensor[seq_len_max]], optional): Masks used to pad text to fixed length in streaming training. Shape is Tensor[seq_len_max]. |
|
""" |
|
|
|
|
|
dummy = False |
|
if self.train: |
|
if len(input_ids) == 0: |
|
dummy = True |
|
dummy_seq_len = 100 |
|
input_ids = [ |
|
torch.full( |
|
(dummy_seq_len,), |
|
fill_value=1, |
|
device=self.model.embed_tokens.weight.device, |
|
dtype=torch.int64 |
|
) |
|
] |
|
input_ids[0][0: self.num_spk_embs] = self.spk_emb_token_id |
|
|
|
if self.config.use_speaker_embedding: |
|
lm_spk_emb_last_hidden_states = [ |
|
torch.full( |
|
(self.num_spk_embs, self.config.llm_dim), |
|
fill_value=0, |
|
device=self.model.embed_tokens.weight.device, |
|
dtype=self.model.embed_tokens.weight.dtype |
|
) |
|
] |
|
else: |
|
lm_last_hidden_states = [ |
|
torch.full( |
|
(dummy_seq_len, self.config.llm_dim), |
|
fill_value=0, |
|
device=self.model.embed_tokens.weight.device, |
|
dtype=self.model.embed_tokens.weight.dtype |
|
) |
|
] |
|
|
|
target_audio_features = [ |
|
torch.full( |
|
(self.num_mel_bins, dummy_seq_len), |
|
fill_value=0, |
|
device=self.model.embed_tokens.weight.device, |
|
dtype=self.model.embed_tokens.weight.dtype |
|
) |
|
] |
|
streaming_tts_text_masks = None |
|
|
|
if lm_last_hidden_states is not None: |
|
assert not self.use_speaker_embedding |
|
|
|
|
|
assert len(lm_last_hidden_states) != 0 |
|
all_tts_condition_seq_len = [i.shape[0] for i in lm_last_hidden_states] |
|
|
|
|
|
input_data = pad_sequence(lm_last_hidden_states, batch_first=True) |
|
|
|
|
|
all_tts_condition = self.projector(input_data) |
|
|
|
|
|
all_tts_condition = F.normalize(all_tts_condition, p=2, dim=2) |
|
|
|
|
|
all_tts_condition_varlen = [] |
|
for idx in range(all_tts_condition.shape[0]): |
|
all_tts_condition_varlen.append(all_tts_condition[idx, 0:all_tts_condition_seq_len[idx]]) |
|
|
|
else: |
|
all_tts_condition_varlen = None |
|
|
|
if lm_spk_emb_last_hidden_states is not None: |
|
assert self.use_speaker_embedding |
|
if len(lm_spk_emb_last_hidden_states) == 0: |
|
raise ValueError("lm_spk_emb_last_hidden_states is empty.") |
|
|
|
stacked_lm_spk_emb_last_hidden_states = torch.stack(lm_spk_emb_last_hidden_states, dim=0) |
|
|
|
|
|
assert stacked_lm_spk_emb_last_hidden_states.shape[1] == self.num_spk_embs |
|
|
|
|
|
gpt_spk_emb_last_hidden_states = self.projector( |
|
stacked_lm_spk_emb_last_hidden_states) |
|
|
|
|
|
gpt_spk_emb_last_hidden_states = F.normalize(gpt_spk_emb_last_hidden_states, p=2, dim=-1) |
|
|
|
else: |
|
gpt_spk_emb_last_hidden_states = None |
|
|
|
|
|
if target_audio_features is not None: |
|
assert not self.dvae.coef.requires_grad
|
with torch.inference_mode(): |
|
eos_token_id = int(self.emb_code[0].num_embeddings - 1) |
|
all_audio_codes = [] |
|
|
|
with torch.cuda.amp.autocast(dtype=torch.float): |
|
for audio_waveform in target_audio_features: |
|
audio_codes = self.dvae(audio_waveform, mode="encode") |
|
|
|
audio_codes_with_eos = torch.cat( |
|
( |
|
audio_codes.squeeze(0), |
|
torch.ones(self.num_vq, 1, device=audio_codes.device, |
|
dtype=audio_codes.dtype) * eos_token_id |
|
), dim=-1 |
|
) |
|
all_audio_codes.append(audio_codes_with_eos) |
|
|
|
all_audio_codes_seq_len = [i.shape[1] for i in all_audio_codes] |
|
|
|
|
|
audio_embed_all_layers = [] |
|
for i in range(self.num_vq): |
|
audio_codes_layer_i = [] |
|
for codes in all_audio_codes: |
|
audio_codes_layer_i.append( |
|
codes[i, :].squeeze(0), |
|
) |
|
|
|
audio_codes_layer_i = pad_sequence(audio_codes_layer_i, batch_first=True) |
|
|
|
audio_embed_layer_i = self.emb_code[i](audio_codes_layer_i) |
|
audio_embed_all_layers.append(audio_embed_layer_i) |
|
|
|
|
|
|
|
audio_embed_all_layers = torch.stack(audio_embed_all_layers, dim=0) |
|
audio_embed_all_layers = torch.sum(audio_embed_all_layers, dim=0, |
|
keepdim=False) |
|
|
|
|
|
audio_embed_all_layers_varlen = [] |
|
for idx in range(audio_embed_all_layers.shape[0]): |
|
audio_embed_all_layers_varlen.append( |
|
audio_embed_all_layers[idx, 0:all_audio_codes_seq_len[idx]] |
|
) |
|
|
|
|
|
all_input_ids_seq_len = [i.shape[0] for i in input_ids] |
|
input_ids = pad_sequence(input_ids, batch_first=True) |
|
all_text_embeds = self.emb_text(input_ids) |
|
|
|
|
|
if lm_spk_emb_last_hidden_states is not None: |
|
|
|
apply_spk_emb( |
|
input_ids=input_ids, |
|
spk_emb=gpt_spk_emb_last_hidden_states, |
|
input_embeds=all_text_embeds, |
|
spk_emb_token_id=self.spk_emb_token_id, |
|
num_spk_embs=self.num_spk_embs, |
|
) |
|
|
|
all_text_embeds_varlen = [] |
|
|
|
for idx in range(all_text_embeds.shape[0]): |
|
all_text_embeds_varlen.append( |
|
all_text_embeds[idx, 0:all_input_ids_seq_len[idx], :] |
|
) |
|
|
|
|
|
|
|
|
|
|
|
embeds_to_merge = [] |
|
|
|
|
|
if lm_last_hidden_states is not None: |
|
embeds_to_merge.append(all_tts_condition_varlen) |
|
|
|
|
|
if self.use_text: |
|
embeds_to_merge.append(all_text_embeds_varlen) |
|
|
|
|
|
if target_audio_features is not None: |
|
embeds_to_merge.append(audio_embed_all_layers_varlen) |
|
|
|
|
|
all_merged_embeds_ = [] |
|
for item_tuple in zip(*embeds_to_merge): |
|
|
|
merged_embed = torch.cat(item_tuple, dim=0) |
|
all_merged_embeds_.append(merged_embed) |
|
|
|
input_embeds_seqlen = [] |
|
for i in all_merged_embeds_: |
|
input_embeds_seqlen.append(i.shape[0]) |
|
|
|
|
|
|
|
input_embeds = pad_sequence(all_merged_embeds_, |
|
batch_first=True) |
|
|
|
|
|
text_ranges = [] |
|
batch_size = input_embeds.shape[0] |
|
for idx in range(batch_size): |
|
start_idx = 0 |
|
|
|
|
|
if lm_last_hidden_states is not None: |
|
start_idx += all_tts_condition_seq_len[idx] |
|
|
|
end_idx = start_idx + all_input_ids_seq_len[idx] |
|
text_ranges.append((start_idx, end_idx)) |
|
|
|
if target_audio_features is not None: |
|
|
|
batch_size = input_embeds.shape[0] |
|
seq_len_max = input_embeds.shape[1] |
|
|
|
|
|
labels = torch.full((batch_size, seq_len_max, self.num_vq), -100, device=input_embeds.device, dtype=torch.long)
|
|
|
|
|
audio_codes_ranges = [] |
|
for idx in range(batch_size): |
|
start_idx = 0 |
|
|
|
|
|
if lm_last_hidden_states is not None: |
|
start_idx += all_tts_condition_seq_len[idx] |
|
|
|
if self.use_text: |
|
start_idx += all_input_ids_seq_len[idx] |
|
|
|
end_idx = start_idx + all_audio_codes_seq_len[idx] |
|
audio_codes_ranges.append((start_idx, end_idx)) |
|
|
|
|
|
for idx, audio_codes_range in zip(range(batch_size), audio_codes_ranges): |
|
start_idx = audio_codes_range[0] |
|
end_idx = audio_codes_range[1] |
|
labels[ |
|
idx, start_idx: end_idx, : |
|
] = all_audio_codes[idx].permute(1, 0) |
|
|
|
|
|
|
|
|
|
|
|
if self.streaming and not dummy: |
|
tts_attention_mask_4d = make_streaming_chunk_mask( |
|
input_embeds=input_embeds, |
|
tts_text_scopes=text_ranges, |
|
tts_audio_scopes=audio_codes_ranges, |
|
tts_text_masks=streaming_tts_text_masks, |
|
min_chunk_num_token=self.streaming_text_chunk_min, |
|
max_chunk_num_token=self.streaming_text_chunk_max, |
|
streaming_audio_chunk_size=self.streaming_audio_chunk_size, |
|
) |
|
else: |
|
tts_attention_mask_4d = None |
|
|
|
|
|
|
|
|
|
outputs = self.model( |
|
inputs_embeds=input_embeds, |
|
attention_mask=tts_attention_mask_4d, |
|
) |
|
|
|
tts_last_hidden_state = outputs.last_hidden_state |
|
|
|
|
|
logits_all_vq_layers = []
for num_vq_iter in range(self.num_vq):
    logits_i = self.head_code[num_vq_iter](tts_last_hidden_state)
    logits_all_vq_layers.append(logits_i)
# [num_vq, batch_size, seq_len, num_audio_tokens] -> [batch_size, seq_len, num_vq, num_audio_tokens]
logits_all_vq_layers = torch.stack(logits_all_vq_layers, dim=0)
logits_all_vq_layers = logits_all_vq_layers.permute(1, 2, 0, 3)

# shift for next-code prediction
shift_logits = logits_all_vq_layers[:, :-1, :, :].contiguous()
shift_labels = labels[:, 1:, :].contiguous()
|
|
|
|
|
if not self.aug_loss_weight: |
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct( |
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) |
|
) |
|
else: |
|
loss_fct = nn.CrossEntropyLoss(reduction='none') |
|
losses = loss_fct( |
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
shift_labels.view(-1).to(shift_logits.device) |
|
).view(shift_labels.size()) |
|
|
|
valid_label_count = (shift_labels != -100).sum() |
|
|
|
eos_token_id = int(self.emb_code[0].num_embeddings - 1)
|
eos_positions = (shift_labels == eos_token_id).nonzero() |
|
for pos in eos_positions: |
|
seq_len = pos[1] + 1 |
|
if seq_len < 400: |
|
losses[pos[0], pos[1], pos[2]] *= 0.2 |
|
elif seq_len > 650: |
|
losses[pos[0], pos[1], pos[2]] *= 2 |
|
|
|
loss = losses.sum() / valid_label_count |
|
|
|
if dummy: |
|
print("dummy loss", loss) |
|
loss = loss * 0 |
|
|
|
else: |
|
loss = None |
|
|
|
return loss |
|
|
|
@torch.inference_mode() |
|
def prepare_inputs_embeds( |
|
self, |
|
input_ids: torch.Tensor, |
|
lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None, |
|
lm_last_hidden_states: Optional[torch.Tensor] = None |
|
): |
|
"""Prepare inputs_embeds for the model in inference mode, |
|
encode input_ids to embeddings, then merge lm_spk_emb_last_hidden_states, and lm_last_hidden_states. |
|
|
|
Args: |
|
input_ids (torch.Tensor): Input token IDs. |
|
lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None. |
|
lm_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states from the language model. Defaults to None. |
|
|
|
Raises:
NotImplementedError: If speaker embeddings are not used; conditioning on full LLM last hidden states is not implemented in inference mode.
|
|
|
Returns: |
|
torch.Tensor: Prepared input embeddings for the model. |
|
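Example (an illustrative sketch; `tts_module`, `tts_input_ids`, and `spk_emb_hidden` are assumed to come from the surrounding pipeline):

```python
# tts_input_ids: [1, seq_len] text token ids containing `spk_emb_token_id` placeholders
# spk_emb_hidden: [1, num_spk_embs, llm_dim] speaker-embedding hidden states from the LLM
inputs_embeds = tts_module.prepare_inputs_embeds(
    input_ids=tts_input_ids,
    lm_spk_emb_last_hidden_states=spk_emb_hidden,
)
# inputs_embeds: [1, seq_len, hidden_size], with the speaker slots replaced by
# the projected and L2-normalized speaker embeddings
```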
""" |
|
assert input_ids.shape[0] == 1 |
|
|
|
|
|
inputs_embeds = self.emb_text(input_ids) |
|
|
|
|
|
if self.use_speaker_embedding: |
|
spk_emb_mask = input_ids == self.spk_emb_token_id |
|
if spk_emb_mask.any(): |
|
assert lm_spk_emb_last_hidden_states is not None |
|
|
|
lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(self.projector.linear1.weight.dtype) |
|
projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states) |
|
projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1) |
|
apply_spk_emb( |
|
input_ids=input_ids, |
|
spk_emb=projected_spk_emb, |
|
input_embeds=inputs_embeds, |
|
spk_emb_token_id=self.spk_emb_token_id, |
|
num_spk_embs=self.num_spk_embs |
|
) |
|
else: |
|
assert lm_last_hidden_states is not None |
|
|
|
raise NotImplementedError |
|
|
|
return inputs_embeds |
|
|
|
@torch.inference_mode() |
|
def prefill_text( |
|
self, |
|
input_ids: torch.Tensor, |
|
position_ids: torch.LongTensor, |
|
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], |
|
lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None, |
|
lm_last_hidden_states: Optional[torch.Tensor] = None |
|
): |
|
"""Prefill a chunk of new text tokens in streaming setting. |
|
Specifically speaking, update `past_key_values` using new text tokens. |
|
|
|
Args: |
|
input_ids (Tensor): Tensor of shape [batch_size, seq_len] |
|
position_ids (LongTensor): Tensor of shape [batch_size, seq_len] |
|
past_key_values (List[Tuple[Tensor]]): KV cache for all layers; each layer is a tuple (keys, values). Each tensor is pre-allocated to cover the `self.streaming_text_reserved_len` text positions. `past_key_values` is updated in place.
|
lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim]. Defaults to None. |
|
lm_last_hidden_states (Tensor, optional): LLM last hidden states of shape [batch_size, seq_len, llm_dim]. Defaults to None.
|
|
|
Note that the `batch_size` of all inputs must be `1`.
|
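Example (an illustrative sketch; the chunking and the `device`, `text_chunk_ids`, and `spk_emb_hidden` names are assumptions):

```python
# prefill one chunk of text token ids [1, chunk_len] occupying positions [begin, end)
position_ids = torch.arange(begin, end, dtype=torch.long, device=device).unsqueeze(0)
past_key_values = tts_module.prefill_text(
    input_ids=text_chunk_ids,
    position_ids=position_ids,
    past_key_values=past_key_values,
    lm_spk_emb_last_hidden_states=spk_emb_hidden,
)
```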
""" |
|
assert input_ids.shape[0] == 1 |
|
assert past_key_values is not None |
|
|
|
|
|
inputs_embeds = self.prepare_inputs_embeds( |
|
input_ids=input_ids, |
|
lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, |
|
lm_last_hidden_states=lm_last_hidden_states, |
|
) |
|
|
|
|
|
past_key_values_for_prefill = [] |
|
for i in range(len(past_key_values)): |
|
past_key_values_for_prefill.append( |
|
( |
|
past_key_values[i][0][:, :, :position_ids[:, 0], :].clone(), |
|
past_key_values[i][1][:, :, :position_ids[:, 0], :].clone(), |
|
) |
|
) |
|
|
|
|
|
outputs_prefill: BaseModelOutputWithPast = self.model( |
|
attention_mask=None, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values_for_prefill, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=True, |
|
output_attentions=False, |
|
cache_position=position_ids, |
|
) |
|
|
|
|
|
past_key_values_for_prefill_updated = outputs_prefill.past_key_values |
|
|
|
|
|
for layer_idx in range(len(past_key_values)): |
|
|
|
past_key_values[layer_idx][0][:, :, position_ids[:, 0]:position_ids[:, -1] + 1, :] = \ |
|
past_key_values_for_prefill_updated[layer_idx][0][:, :, position_ids[:, 0]:position_ids[:, -1] + 1].clone() |
|
|
|
past_key_values[layer_idx][1][:, :, position_ids[:, 0]:position_ids[:, -1] + 1, :] = \ |
|
past_key_values_for_prefill_updated[layer_idx][1][:, :, position_ids[:, 0]:position_ids[:, -1] + 1].clone() |
|
|
|
|
|
|
|
|
|
return past_key_values |
|
|
|
@torch.inference_mode() |
|
def generate( |
|
self, |
|
input_ids: torch.Tensor, |
|
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], |
|
temperature: torch.Tensor, |
|
eos_token: Union[int, torch.Tensor], |
|
streaming_tts_text_mask=None, |
|
force_no_stop=False, |
|
min_new_token=10, |
|
max_new_token=50, |
|
logits_warpers: List[LogitsWarper] = [], |
|
logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], |
|
show_tqdm=False, |
|
): |
|
"""Generate audio codes in streaming setting. |
|
Specifically speaking, generate audio codes when not all text tokens are prefilled. |
|
|
|
Usage: |
|
Always pass an non-empty `past_key_values` to the function. The function does not do `prefill` by itself. It relies on `prefill_text` method to provide a valid `past_key_values`. |
|
|
|
1. Create an empty `past_key_values` with |
|
```python |
|
initial_kv_cache_length = 1 + self.num_spk_embs + self.streaming_text_reserved_len |
|
dtype = model.emb_text.weight.dtype |
|
device = model.emb_text.weight.device |
|
past_key_values = [ |
|
( |
|
torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device), |
|
torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device) |
|
) |
|
for _ in range(model.config.num_hidden_layers) |
|
]
```
|
|
|
2. Prefill some text tokens using the `prefill_text` method.
```python
outputs = llm.generate(**kwargs)
# extract either the speaker-embedding hidden states or the full last hidden states from the LLM outputs
lm_spk_emb_last_hidden_states = extract(outputs.last_hidden_states)  # or: lm_last_hidden_states = ...
|
input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens)) |
|
position_ids = torch.arange(begin, end, dtype=torch.long, device=device) |
|
past_key_values = self.prefill_text( |
|
input_ids=input_ids, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, |
|
lm_last_hidden_states=lm_last_hidden_states, |
|
) |
|
``` |
|
|
|
3. Generate audio codes using the `generate` method.
```python
# initialize input_ids; this should be done only once
condition_length = 1 + model.num_spk_embs * model.use_speaker_embedding + model.streaming_text_reserved_len + 1
input_ids = torch.zeros(1, condition_length, model.num_vq, dtype=torch.long)  # batch_size = 1
|
|
|
outputs = model.generate(
|
input_ids=input_ids, |
|
past_key_values=past_key_values, |
|
) |
|
|
|
# update past_key_values and input_ids |
|
past_key_values = outputs.past_key_values |
|
input_ids = outputs.audio_input_ids
|
``` |
|
|
|
4. Repeat step 2 and 3. |
|
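The sampling controls can be built with the module-level `gen_logits` helper defined at the bottom of this file (an illustrative sketch; the temperature values, the EOS id, and the mask variable are assumptions):

```python
logits_warpers, logits_processors = gen_logits(
    num_code=model.config.num_audio_tokens,
    top_P=model.config.top_p,
    top_K=model.config.top_k,
    repetition_penalty=model.config.repetition_penalty,
)
outputs = model.generate(
    input_ids=input_ids,
    past_key_values=past_key_values,
    temperature=torch.tensor([0.8] * model.num_vq, device=model.device),  # one temperature per VQ codebook (assumed)
    eos_token=model.config.num_audio_tokens - 1,  # assumed: the last audio code id is the EOS code
    streaming_tts_text_mask=streaming_tts_text_mask,
    logits_warpers=logits_warpers,
    logits_processors=logits_processors,
)
```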
|
|
Args: |
|
input_ids (torch.Tensor): Input token ids. |
|
past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism. |
|
temperature (torch.Tensor): Temperature for sampling. |
|
eos_token (Union[int, torch.Tensor]): End of sequence token. |
|
streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
force_no_stop (bool, optional): If True, suppress the EOS token so generation does not stop early. Defaults to False.
min_new_token (int, optional): Minimum number of new tokens to generate before EOS is allowed. Defaults to 10.
max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
show_tqdm (bool, optional): Whether to show a progress bar. Defaults to False.

Returns:
ConditionalChatTTSGenerationOutput: The generated audio codes (`new_ids`), the full audio `input_ids` buffer (`audio_input_ids`), the updated `past_key_values`, and a `finished` flag.
|
""" |
|
|
|
|
|
assert input_ids.shape[0] == 1 |
|
assert past_key_values is not None |
|
|
|
|
|
|
|
start_idx = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1 |
|
|
|
finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool() |
|
|
|
temperature = ( |
|
temperature.unsqueeze(0) |
|
.expand(input_ids.shape[0], -1) |
|
.contiguous() |
|
.view(-1, 1) |
|
) |
|
|
|
progress = input_ids.shape[1] |
|
|
|
|
|
input_ids_buf = torch.zeros( |
|
input_ids.shape[0], |
|
progress + max_new_token, |
|
input_ids.shape[2], |
|
dtype=input_ids.dtype, |
|
device=input_ids.device, |
|
) |
|
|
|
|
|
input_ids_buf.narrow(1, 0, progress).copy_(input_ids) |
|
|
|
del input_ids |
|
input_ids = input_ids_buf.narrow(1, 0, progress) |
|
|
|
pbar: Optional[tqdm] = None |
|
if show_tqdm: |
|
pbar = tqdm( |
|
total=max_new_token, |
|
desc="code", |
|
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", |
|
) |
|
|
|
condition_length = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1 |
|
|
|
for i in range(max_new_token): |
|
|
|
audio_bos = False |
|
|
|
if progress == condition_length: |
|
audio_bos = True |
|
|
|
if audio_bos: |
|
|
|
assert progress == (past_key_values[0][0].shape[2] + 1) |
|
narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device) |
|
inputs_embeds = self.emb_text(narrowed_input_ids) |
|
del narrowed_input_ids |
|
else: |
|
|
|
assert progress == (past_key_values[0][0].shape[2] + 1) |
|
narrowed_input_ids = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1) |
|
code_emb = [ |
|
self.emb_code[i](narrowed_input_ids[:, :, i]) |
|
for i in range(self.num_vq) |
|
] |
|
inputs_embeds = torch.stack(code_emb, 3).sum(3) |
|
|
|
position_ids = torch.tensor( |
|
[past_key_values[0][0].shape[2] + 1], |
|
dtype=torch.long, |
|
device=self.device |
|
).unsqueeze(0) |
|
|
|
cache_position = position_ids.clone() |
|
causal_mask = make_streaming_chunk_mask_generation( |
|
inputs_embeds=inputs_embeds, |
|
past_seen_tokens=past_key_values[0][0].shape[2], |
|
streaming_tts_text_mask=streaming_tts_text_mask, |
|
streaming_reserved_length=self.streaming_text_reserved_len, |
|
streaming_text_chunk_size=self.streaming_text_chunk_size |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs: BaseModelOutputWithPast = self.model( |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=True, |
|
output_attentions=False, |
|
cache_position=cache_position, |
|
) |
|
|
|
del position_ids |
|
del inputs_embeds |
|
del cache_position |
|
del causal_mask |
|
|
|
hidden_states = outputs.last_hidden_state |
|
past_key_values = outputs.past_key_values |
|
|
|
with P.cached(): |
|
logits = torch.empty( |
|
hidden_states.size(0), |
|
hidden_states.size(1), |
|
self.num_audio_tokens, |
|
self.num_vq, |
|
dtype=torch.float, |
|
device=self.device, |
|
) |
|
for num_vq_iter in range(self.num_vq): |
|
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) |
|
logits[..., num_vq_iter] = x |
|
del x |
|
|
|
del hidden_states |
|
|
|
|
|
logits = logits.narrow(1, -1, 1).squeeze_(1).float() |
|
|
|
|
|
logits = logits.permute(0, 2, 1) |
|
logits = logits.reshape(-1, logits.size(2)) |
|
|
|
input_ids_sliced = input_ids.narrow( |
|
1, |
|
start_idx, |
|
input_ids.size(1) - start_idx, |
|
).permute(0, 2, 1) |
|
logits_token = input_ids_sliced.reshape( |
|
input_ids_sliced.size(0) * input_ids_sliced.size(1), |
|
-1, |
|
).to(self.device) |
|
del input_ids_sliced |
|
|
|
logits /= temperature |
|
|
|
if not audio_bos:
    for logits_processor in logits_processors:
        logits = logits_processor(logits_token, logits)
    for logits_warper in logits_warpers:
        logits = logits_warper(logits_token, logits)
|
|
|
del logits_token |
|
|
|
if i < min_new_token: |
|
logits[:, eos_token] = -torch.inf |
|
|
|
if force_no_stop: |
|
logits[:, eos_token] = -torch.inf |
|
|
|
scores = F.softmax(logits, dim=-1) |
|
|
|
del logits |
|
|
|
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) |
|
|
|
del scores |
|
|
|
|
|
idx_next = idx_next.view(-1, self.num_vq) |
|
finish_or = idx_next.eq(eos_token).any(1) |
|
finish.logical_or_(finish_or) |
|
|
|
del finish_or |
|
|
|
input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1)) |
|
|
|
if i == 0 and finish.any(): |
|
|
|
break |
|
|
|
del idx_next |
|
progress += 1 |
|
input_ids = input_ids_buf.narrow(1, 0, progress) |
|
|
|
if finish.all(): |
|
break |
|
|
|
if pbar is not None: |
|
pbar.update(1) |
|
|
|
if pbar is not None: |
|
pbar.close() |
|
|
|
if not finish.all(): |
|
if show_tqdm: |
|
print( |
|
f"incomplete result. hit max_new_token: {max_new_token}" |
|
) |
|
|
|
del input_ids_buf |
|
|
|
if finish.all():
    # drop the trailing EOS code
    generated_input_ids = input_ids[:, condition_length:-1, :]
else:
    generated_input_ids = input_ids[:, condition_length:, :]

return ConditionalChatTTSGenerationOutput(
    new_ids=generated_input_ids,
|
audio_input_ids=input_ids, |
|
past_key_values=past_key_values, |
|
finished=finish.all(), |
|
) |
|
|
|
@torch.inference_mode() |
|
def decode_to_mel_specs( |
|
self, |
|
result_list: List[torch.Tensor], |
|
use_decoder: bool = False, |
|
): |
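"""Pad a list of generated audio code tensors (each of shape [seq_len, num_vq]) into a batch and decode them to mel spectrograms with the DVAE.

A usage sketch (illustrative only; it assumes `outputs` comes from `generate` and that a `Vocos` vocoder instance is available, e.g. via `initialize_vocos`):

```python
mel_specs = tts_module.decode_to_mel_specs([outputs.new_ids.squeeze(0)])
waveform = vocos.decode(mel_specs)  # decode mel spectrograms to a waveform
```
"""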
|
decoder = self.dvae |
|
max_x_len = -1 |
|
if len(result_list) == 0: |
|
return np.array([], dtype=np.float32) |
|
for result in result_list: |
|
if result.size(0) > max_x_len: |
|
max_x_len = result.size(0) |
|
batch_result = torch.zeros( |
|
(len(result_list), result_list[0].size(1), max_x_len), |
|
dtype=result_list[0].dtype, |
|
device=result_list[0].device, |
|
) |
|
for i in range(len(result_list)): |
|
src = result_list[i] |
|
batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0)) |
|
del src |
|
|
|
mel_specs = decoder(batch_result) |
|
del batch_result |
|
return mel_specs |
|
|
|
|
|
def gen_logits( |
|
num_code: int, |
|
top_P=0.7, |
|
top_K=20, |
|
repetition_penalty=1.0, |
|
): |
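"""Build the logits warpers and processors used for audio-code sampling.

Returns a tuple `(logits_warpers, logits_processors)` suitable for passing to the `generate` method of the TTS module above. A usage sketch (argument values are illustrative):

```python
logits_warpers, logits_processors = gen_logits(
    num_code=config.num_audio_tokens,
    top_P=config.top_p,
    top_K=config.top_k,
    repetition_penalty=config.repetition_penalty,
)
```
"""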
|
logits_warpers = [] |
|
if top_P is not None: |
|
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) |
|
if top_K is not None: |
|
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) |
|
|
|
logits_processors = [] |
|
if repetition_penalty is not None and repetition_penalty != 1: |
|
logits_processors.append( |
|
CustomRepetitionPenaltyLogitsProcessorRepeat( |
|
repetition_penalty, num_code, 16 |
|
) |
|
) |
|
|
|
return logits_warpers, logits_processors |
|
|