# Source: Hugging Face Spaces page (Space status at capture time: Sleeping).
import heapq
import math
import os
import random
import re
import time
from typing import List, Optional, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.distributions.distribution import Distribution
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          GPT2LMHeadModel, GPT2Tokenizer,
                          T5ForConditionalGeneration, T5Tokenizer)

from .tools.token_emb import NewTokenEmb
class MLM(nn.Module):
    """Language model over text plus discrete motion-code tokens.

    Motion codes are rendered as ``<motion_id_K>`` strings and added to the
    tokenizer vocabulary. Three extra ids beyond the codebook are reserved:

    * ``<motion_id_{codebook_size}>``     -- start-of-motion marker
    * ``<motion_id_{codebook_size + 1}>`` -- end-of-motion marker
    * ``<motion_id_{codebook_size + 2}>`` -- filler for masked motion tokens

    The wrapped model is either a T5 encoder-decoder (``lm_type == 'encdec'``)
    or a GPT-2 decoder (``lm_type == 'dec'``).
    """

    # Matches one rendered motion token and captures its integer id.
    _MOTION_ID_RE = re.compile(r'<motion_id_(\d+)>')

    def __init__(
        self,
        model_path: str,
        model_type: str = "t5",
        stage: str = "lm_pretrain",
        new_token_type: str = "insert",
        motion_codebook_size: int = 512,
        framerate: float = 20.0,
        down_t: int = 4,
        predict_ratio: float = 0.2,
        inbetween_ratio: float = 0.25,
        max_length: int = 256,
        lora: bool = False,
        quota_ratio: float = 0.5,
        noise_density: float = 0.15,
        mean_noise_span_length: int = 3,
        **kwargs,
    ) -> None:
        """Load tokenizer and language model and extend them with motion tokens.

        Args:
            model_path: Path or hub id passed to ``from_pretrained``.
            model_type: ``"t5"`` (encoder-decoder) or ``"gpt2"`` (decoder).
            stage: Stored on the instance; not otherwise used in this class.
            new_token_type: ``"insert"`` resizes the embedding matrix;
                ``"mlp"`` additionally wraps the shared embedding in
                ``NewTokenEmb``.
            motion_codebook_size: Number of motion codes.
            framerate: Frames per second, used for ``<Second_Placeholder>``.
            down_t: Divisor converting a frame count to a token count.
            predict_ratio: Fraction of tokens kept as the prediction prefix.
            inbetween_ratio: Fraction of tokens kept at each end when masking
                the middle of a motion.
            max_length: Tokenizer padding / truncation length.
            lora: Wrap the language model with a LoRA adapter (needs ``peft``).
            quota_ratio: Probability of wrapping the caption in double quotes.
            noise_density: Fraction of tokens masked by span corruption.
            mean_noise_span_length: Mean length of a corrupted span.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``model_type`` is neither ``"t5"`` nor ``"gpt2"``.
        """
        super().__init__()

        # Parameters
        self.m_codebook_size = motion_codebook_size
        self.max_length = max_length
        self.framerate = framerate
        self.down_t = down_t
        self.predict_ratio = predict_ratio
        self.inbetween_ratio = inbetween_ratio
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.quota_ratio = quota_ratio
        self.stage = stage

        # Instantiate language model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)
        if model_type == "t5":
            self.language_model = T5ForConditionalGeneration.from_pretrained(
                model_path)
            self.lm_type = 'encdec'
        elif model_type == "gpt2":
            self.language_model = GPT2LMHeadModel.from_pretrained(model_path)
            self.lm_type = 'dec'
        else:
            raise ValueError(
                f"model_type must be either 't5' or 'gpt2', got {model_type!r}")

        if self.lm_type == 'dec':
            # The decoder-only branch reuses EOS as the padding token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Add motion tokens: the codebook plus start / end / mask markers.
        self.tokenizer.add_tokens(
            [f'<motion_id_{i}>' for i in range(self.m_codebook_size + 3)])

        if new_token_type == "insert":
            self.language_model.resize_token_embeddings(len(self.tokenizer))
        elif new_token_type == "mlp":
            # Wrap the embedding *before* resizing, then re-attach the wrapper.
            shared = NewTokenEmb(self.language_model.shared,
                                 self.m_codebook_size + 3)
            self.language_model.resize_token_embeddings(len(self.tokenizer))
            self.language_model.shared = shared

        # Lora
        if lora:
            from peft import LoraConfig, get_peft_model

            peft_config = LoraConfig(
                bias="none",
                # The PEFT task type has to match the wrapped architecture.
                task_type=("SEQ_2_SEQ_LM"
                           if self.lm_type == 'encdec' else "CAUSAL_LM"),
                r=8,
                lora_alpha=16,
                lora_dropout=0.05)
            self.language_model = get_peft_model(self.language_model,
                                                 peft_config)

    def forward(self, texts: List[str], motion_tokens: Tensor,
                lengths: List[int], tasks: dict):
        """Run one training step and return the language model's output.

        Dispatches to :meth:`forward_encdec` or :meth:`forward_dec` depending
        on ``self.lm_type``.

        Raises:
            NotImplementedError: If ``self.lm_type`` is not a known value.
        """
        if self.lm_type == 'encdec':
            return self.forward_encdec(texts, motion_tokens, lengths, tasks)
        if self.lm_type == 'dec':
            return self.forward_dec(texts, motion_tokens, lengths, tasks)
        raise NotImplementedError(
            f"Unsupported lm_type {self.lm_type!r}; expected 'encdec' or 'dec'")

    def forward_encdec(
        self,
        texts: List[str],
        motion_tokens: Tensor,
        lengths: List[int],
        tasks: dict,
    ):
        """Training forward pass for the encoder-decoder model.

        Args:
            texts: One caption per sample.
            motion_tokens: Motion codes per sample; row ``i`` is cut to
                ``lengths[i]`` entries before being rendered as a string.
            lengths: Per-sample length, also substituted for
                ``<Frame_Placeholder>``.
            tasks: Per-sample mapping with ``'input'`` and ``'output'`` lists
                of prompt templates.

        Returns:
            Whatever the wrapped language model returns when given labels.
        """
        device = motion_tokens.device

        # Tensor to string
        motion_strings = self.motion_token_to_string(motion_tokens, lengths)

        # Supervised or unsupervised. Only 'supervised' can currently be
        # drawn; the 'text' / 'motion' span-corruption branches are kept for
        # when the choice list is widened again.
        condition = random.choice(['supervised', 'supervised', 'supervised'])
        if condition == 'text':
            inputs, outputs = texts, texts
        elif condition == 'motion':
            inputs, outputs = motion_strings, motion_strings
        else:
            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)

        # Tokenize
        source_encoding = self.tokenizer(inputs,
                                         padding='max_length',
                                         max_length=self.max_length,
                                         truncation=True,
                                         return_attention_mask=True,
                                         add_special_tokens=True,
                                         return_tensors="pt")
        source_attention_mask = source_encoding.attention_mask.to(device)
        source_input_ids = source_encoding.input_ids.to(device)

        if condition in ('text', 'motion'):
            # Span corruption: noise spans become sentinels in the input and
            # the complementary spans become sentinels in the target.
            batch_size, expanded_input_length = source_input_ids.shape
            mask_indices = np.asarray([
                self.random_spans_noise_mask(expanded_input_length)
                for _ in range(batch_size)
            ])
            target_mask = ~mask_indices
            input_ids_sentinel = self.create_sentinel_ids(
                mask_indices.astype(np.int8))
            target_sentinel = self.create_sentinel_ids(
                target_mask.astype(np.int8))

            labels_input_ids = self.filter_input_ids(source_input_ids,
                                                     target_sentinel)
            source_input_ids = self.filter_input_ids(source_input_ids,
                                                     input_ids_sentinel)
            # The filtered sequences no longer line up with the tokenizer's
            # masks, so no masks are passed in this branch.
            source_attention_mask = None
            labels_attention_mask = None
        else:
            target_inputs = self.tokenizer(outputs,
                                           padding='max_length',
                                           max_length=self.max_length,
                                           truncation=True,
                                           return_attention_mask=True,
                                           add_special_tokens=True,
                                           return_tensors="pt")
            labels_input_ids = target_inputs.input_ids.to(device)
            labels_attention_mask = target_inputs.attention_mask.to(device)

        # Padding positions must not contribute to the loss.
        labels_input_ids[labels_input_ids ==
                         self.tokenizer.pad_token_id] = -100

        return self.language_model(
            input_ids=source_input_ids,
            attention_mask=source_attention_mask,
            labels=labels_input_ids,
            decoder_attention_mask=labels_attention_mask,
        )

    def forward_dec(
        self,
        texts: List[str],
        motion_tokens: Tensor,
        lengths: List[int],
        tasks: dict,
    ):
        """Training forward pass for the decoder-only model.

        With probability 2/5 the raw captions or raw motion strings are used
        as the training sequence; otherwise the filled prompt and target are
        joined with ``' \\n '`` and terminated with the EOS token.

        Returns:
            Whatever the wrapped language model returns when given labels.
        """
        self.tokenizer.padding_side = "right"
        device = motion_tokens.device

        # Tensor to string
        motion_strings = self.motion_token_to_string(motion_tokens, lengths)

        # Supervised or unsupervised
        condition = random.choice(
            ['text', 'motion', 'supervised', 'supervised', 'supervised'])
        if condition == 'text':
            labels = texts
        elif condition == 'motion':
            labels = motion_strings
        else:
            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)
            eos = self.tokenizer.eos_token
            labels = [
                prompt + ' \n ' + target + eos
                for prompt, target in zip(inputs, outputs)
            ]

        # Tokenize
        encoding = self.tokenizer(labels,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  truncation=True,
                                  return_attention_mask=True,
                                  return_tensors="pt")
        labels_input_ids = encoding.input_ids.to(device)
        labels_attention_mask = encoding.attention_mask.to(device)

        # Labels must live on the same device as the inputs; the tokenizer
        # output itself is still on the CPU.
        return self.language_model(input_ids=labels_input_ids,
                                   attention_mask=labels_attention_mask,
                                   labels=labels_input_ids)

    def generate_direct(self,
                        texts: List[str],
                        max_length: int = 256,
                        num_beams: int = 1,
                        do_sample: bool = True,
                        bad_words_ids: Optional[List[List[int]]] = None):
        """Generate from already-filled prompts.

        Args:
            texts: Prompt strings.
            max_length: ``max_length`` for the encoder-decoder model,
                ``max_new_tokens`` for the decoder-only model. Prompts are
                always tokenized to ``self.max_length``.
            num_beams: Beam count (encoder-decoder only).
            do_sample: Whether to sample.
            bad_words_ids: Forwarded to ``generate`` (encoder-decoder only).

        Returns:
            ``(motion_tokens, cleaned_text)`` as produced by
            :meth:`motion_string_to_token` on the decoded generations.
        """
        # Device
        self.device = self.language_model.device

        # Tokenize
        if self.lm_type == 'dec':
            # Decoder-only generation continues from the last position, so
            # padding must be on the left *before* the prompts are tokenized.
            self.tokenizer.padding_side = 'left'
            texts = [text + " \n " for text in texts]

        source_encoding = self.tokenizer(texts,
                                         padding='max_length',
                                         max_length=self.max_length,
                                         truncation=True,
                                         return_attention_mask=True,
                                         add_special_tokens=True,
                                         return_tensors="pt")
        source_input_ids = source_encoding.input_ids.to(self.device)
        source_attention_mask = source_encoding.attention_mask.to(self.device)

        if self.lm_type == 'encdec':
            outputs = self.language_model.generate(
                input_ids=source_input_ids,
                attention_mask=source_attention_mask,
                max_length=max_length,
                num_beams=num_beams,
                do_sample=do_sample,
                bad_words_ids=bad_words_ids,
            )
        elif self.lm_type == 'dec':
            outputs = self.language_model.generate(
                input_ids=source_input_ids,
                attention_mask=source_attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=do_sample,
                max_new_tokens=max_length)

        outputs_string = self.tokenizer.batch_decode(outputs,
                                                     skip_special_tokens=True)

        print(texts[:2])
        print(outputs_string[:2])

        outputs_tokens, cleaned_text = self.motion_string_to_token(
            outputs_string)

        return outputs_tokens, cleaned_text

    def generate_conditional(self,
                             texts: Optional[List[str]] = None,
                             motion_tokens: Optional[Tensor] = None,
                             lengths: Optional[List[int]] = None,
                             task: str = "t2m",
                             with_len: bool = False,
                             stage: str = 'train',
                             tasks: dict = None):
        """Build task-specific prompts and generate.

        Args:
            texts: Captions; required for ``"t2m"``.
            motion_tokens: Motion codes; required for ``"pred"``,
                ``"inbetween"`` and ``"m2t"``.
            lengths: Per-sample lengths; required whenever ``motion_tokens``
                is, and for ``"t2m"`` with ``with_len=True``.
            task: One of ``"t2m"``, ``"pred"``, ``"inbetween"``, ``"m2t"``.
            with_len: Use the prompt variant that states the frame count.
            stage: Forwarded to :meth:`template_fulfill`.
            tasks: Optional custom templates; only honoured for ``"t2m"``
                without ``with_len``.

        Returns:
            A list of motion-token tensors for the motion-producing tasks, or
            a list of strings for ``"m2t"``.

        Raises:
            NotImplementedError: For ``"m2m"``, which has no prompt template.
            ValueError: For an unknown task, or when ``with_len`` is set for
                ``"t2m"`` without ``lengths``.
        """
        if task == "m2m":
            raise NotImplementedError(
                "task 'm2m' has no prompt template in generate_conditional")
        if task not in ("t2m", "pred", "inbetween", "m2t"):
            raise ValueError(f"Unknown task {task!r}")

        self.device = self.language_model.device

        if task == "m2t":
            assert motion_tokens is not None and lengths is not None, \
                "m2t requires motion_tokens and lengths"

            motion_strings = self.motion_token_to_string(
                motion_tokens, lengths)

            if not with_len:
                tasks = [{
                    'input': ['Generate text: <Motion_Placeholder>'],
                    'output': ['']
                }] * len(lengths)
            else:
                tasks = [{
                    'input': [
                        'Generate text with <Frame_Placeholder> frames: <Motion_Placeholder>'
                    ],
                    'output': ['']
                }] * len(lengths)

            texts = [''] * len(lengths)

            inputs, outputs = self.template_fulfill(tasks, lengths,
                                                    motion_strings, texts)
            outputs_tokens, cleaned_text = self.generate_direct(
                inputs,
                max_length=40,
                num_beams=1,
                do_sample=False,
            )
            return cleaned_text

        # Motion-producing tasks.
        if task == "t2m":
            assert texts is not None, "t2m requires texts"
            motion_strings = [''] * len(texts)
            if not with_len:
                if tasks is None:
                    tasks = [{
                        'input': ['Generate motion: <Caption_Placeholder>'],
                        'output': ['']
                    }] * len(texts)
                # Without a length hint every sample is filled with length 0.
                lengths = [0] * len(texts)
            else:
                if lengths is None:
                    raise ValueError(
                        "lengths is required when with_len=True")
                tasks = [{
                    'input': [
                        'Generate motion with <Frame_Placeholder> frames: <Caption_Placeholder>'
                    ],
                    'output': ['']
                }] * len(texts)

        elif task == "pred":
            assert motion_tokens is not None and lengths is not None, \
                "pred requires motion_tokens and lengths"
            texts = [''] * len(lengths)
            tasks = [{
                'input': ['Predict motion: <Motion_Placeholder_s1>'],
                'output': ['']
            }] * len(lengths)

            motion_strings_old = self.motion_token_to_string(
                motion_tokens, lengths)
            motion_strings = []
            for i, length in enumerate(lengths):
                # Keep the first length // 5 '>'-delimited pieces (the start
                # marker counts as one piece).
                split = length // 5
                motion_strings.append(
                    '>'.join(motion_strings_old[i].split('>')[:split]) + '>')

        else:  # task == "inbetween"
            assert motion_tokens is not None and lengths is not None, \
                "inbetween requires motion_tokens and lengths"
            texts = [''] * len(lengths)
            tasks = [{
                'input': [
                    "Complete the masked motion: <Motion_Placeholder_Masked>"
                ],
                'output': ['']
            }] * len(lengths)
            motion_strings = self.motion_token_to_string(
                motion_tokens, lengths)

        inputs, outputs = self.template_fulfill(tasks, lengths, motion_strings,
                                                texts, stage)

        outputs_tokens, cleaned_text = self.generate_direct(inputs,
                                                            max_length=128,
                                                            num_beams=1,
                                                            do_sample=True)

        return outputs_tokens

    def _format_motion_ids(self, motion_ids) -> str:
        """Render a sequence of ids as a start/end-delimited motion string."""
        body = ''.join(f'<motion_id_{int(code)}>' for code in motion_ids)
        return (f'<motion_id_{self.m_codebook_size}>' + body +
                f'<motion_id_{self.m_codebook_size + 1}>')

    def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]):
        """Render each row of ``motion_token`` as a motion string.

        Row ``i`` is cut to its first ``lengths[i]`` entries and wrapped in
        the start and end markers.

        Returns:
            A list with one string per row.
        """
        # ``tolist`` copies to host memory itself, whatever the device.
        return [
            self._format_motion_ids(row.tolist()[:lengths[i]])
            for i, row in enumerate(motion_token)
        ]

    def motion_token_list_to_string(self, motion_token: Tensor):
        """Like :meth:`motion_token_to_string`, but uses every entry per row."""
        return [self._format_motion_ids(row.tolist()) for row in motion_token]

    def motion_string_to_token(self, motion_string: List[str]):
        """Parse motion ids back out of decoded strings.

        For every string, the span between the first start marker and the
        following end marker is extracted (see :meth:`get_middle_str`), the
        ids inside it are collected, and the span is replaced by
        ``<Motion_Placeholder>`` in the returned text.

        Returns:
            ``(motion_tokens, output_string)``: a list of 1-D integer tensors
            (``[0]`` when no id was found) and the list of cleaned strings.
        """
        # Prefer the device recorded by the generate_* methods; fall back to
        # the model's device when this is called on its own.
        device = getattr(self, 'device', None)
        if device is None:
            device = self.language_model.device

        start_marker = f'<motion_id_{self.m_codebook_size}>'
        end_marker = f'<motion_id_{self.m_codebook_size + 1}>'

        motion_tokens = []
        output_string = []
        for decoded in motion_string:
            span = self.get_middle_str(decoded, start_marker, end_marker)
            # The span always begins with the start marker and ends with the
            # end marker; drop both. Text that is not a motion token is
            # ignored instead of breaking the integer parse.
            token_list = [
                int(code) for code in self._MOTION_ID_RE.findall(span)
            ][1:-1]
            if not token_list:
                token_list = [0]
            motion_tokens.append(
                torch.tensor(token_list, dtype=torch.long, device=device))
            output_string.append(
                decoded.replace(span, '<Motion_Placeholder>'))

        return motion_tokens, output_string

    def placeholder_fulfill(self, prompt: str, length: int, motion_string: str,
                            text: str):
        """Substitute every placeholder in one prompt template.

        Args:
            prompt: Template containing any of the ``<..._Placeholder...>``
                markers.
            length: Length of the sample, substituted for
                ``<Frame_Placeholder>``.
            motion_string: Full motion string of the sample.
            text: Caption; wrapped in double quotes with probability
                ``self.quota_ratio``.

        Returns:
            The filled prompt.
        """
        # Floored to whole seconds before formatting, so the rendered value
        # always ends in ".0".
        seconds = math.floor(length / self.framerate)
        motion_pieces = motion_string.split('>')
        token_length = length / self.down_t
        predict_head = int(token_length * self.predict_ratio + 1)
        masked_head = int(token_length * self.inbetween_ratio + 1)
        masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1)

        start_marker = f'<motion_id_{self.m_codebook_size}>'
        end_marker = f'<motion_id_{self.m_codebook_size + 1}>'
        mask_marker = f'<motion_id_{self.m_codebook_size + 2}>'

        # Prefix of the motion, closed with an end marker.
        motion_predict_head = ('>'.join(motion_pieces[:predict_head]) + '>' +
                               end_marker)
        # Remainder of the motion, opened with a start marker.
        motion_predict_last = start_marker + '>'.join(
            motion_pieces[predict_head:])
        # Motion with the middle section replaced by mask markers.
        motion_masked = ('>'.join(motion_pieces[:masked_head]) + '>' +
                         mask_marker * (masked_tail - masked_head) +
                         '>'.join(motion_pieces[masked_tail:]))

        if random.random() < self.quota_ratio:
            text = '"' + text + '"'

        # Applied in this order, one after another.
        replacements = (
            ('<Caption_Placeholder>', text),
            ('<Motion_Placeholder>', motion_string),
            ('<Frame_Placeholder>', f'{length}'),
            ('<Second_Placeholder>', '%.1f' % seconds),
            ('<Motion_Placeholder_s1>', motion_predict_head),
            ('<Motion_Placeholder_s2>', motion_predict_last),
            ('<Motion_Placeholder_Masked>', motion_masked),
        )
        for placeholder, value in replacements:
            prompt = prompt.replace(placeholder, value)

        return prompt

    def template_fulfill(self,
                         tasks,
                         lengths,
                         motion_strings,
                         texts,
                         stage='test'):
        """Pick and fill one input and one output template per sample.

        Args:
            tasks: Per-sample mapping with ``'input'`` and ``'output'`` lists
                of templates; one of each is drawn at random.
            lengths: Per-sample lengths; also determines the sample count.
            motion_strings: Per-sample motion strings.
            texts: Per-sample captions.
            stage: Accepted for interface compatibility; unused.

        Returns:
            ``(inputs, outputs)``: two lists of filled prompts.
        """
        inputs = []
        outputs = []
        for i, length in enumerate(lengths):
            # Both templates are drawn before either is filled so the order
            # of random draws is stable.
            input_template = random.choice(tasks[i]['input'])
            output_template = random.choice(tasks[i]['output'])
            inputs.append(
                self.placeholder_fulfill(input_template, length,
                                         motion_strings[i], texts[i]))
            outputs.append(
                self.placeholder_fulfill(output_template, length,
                                         motion_strings[i], texts[i]))

        return inputs, outputs

    def get_middle_str(self, content, startStr, endStr):
        """Return the first ``startStr ... endStr`` span of ``content``.

        The span is rebuilt with this model's own start and end motion
        markers around whatever lay between the two delimiters. The end
        delimiter is searched for only *after* the start delimiter.

        Returns:
            The re-wrapped span, or a span holding the single token
            ``<motion_id_0>`` when either delimiter is missing.
        """
        start_marker = f'<motion_id_{self.m_codebook_size}>'
        end_marker = f'<motion_id_{self.m_codebook_size + 1}>'
        try:
            start_index = content.index(startStr) + len(startStr)
            end_index = content.index(endStr, start_index)
        except ValueError:
            return start_marker + '<motion_id_0>' + end_marker

        return start_marker + content[start_index:end_index] + end_marker

    def random_spans_noise_mask(self, length):
        """Draw a boolean span-corruption mask of the given length.

        Adapted from
        https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py

        Roughly ``length * self.noise_density`` positions are marked as
        noise, grouped into spans of mean length
        ``self.mean_noise_span_length``; the mask starts with a non-noise
        span. Uses NumPy's global random state.

        Returns:
            A boolean array of shape ``(length,)``; ``True`` marks noise.
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(
            np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.

            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                an array with shape [num_segments] containing positive
                integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens,
                                                  num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens,
                                                     num_noise_spans)

        # Alternate non-noise and noise spans, starting with non-noise.
        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length, ), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        # Odd-numbered spans are the noise spans.
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """Turn a 0/1 span mask into sentinel ids.

        Adapted from
        https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py

        Args:
            mask_indices: Integer array of shape ``(batch, seq)`` with 1 on
                masked positions.

        Returns:
            An array of the same shape: the first position of each masked
            span holds a sentinel id, the remaining positions of the span
            hold ``-1``, and unmasked positions hold ``0``.
        """
        # 1 exactly where a masked span begins.
        start_indices = mask_indices - np.roll(mask_indices, 1,
                                               axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        # Number the spans 1, 2, ... at their start positions.
        sentinel_ids = np.where(start_indices != 0,
                                np.cumsum(start_indices, axis=-1),
                                start_indices)
        # NOTE(review): ids count down from len(self.tokenizer). Because
        # __init__ appends the motion tokens to the tokenizer, the top ids
        # presumably belong to <motion_id_*> tokens rather than T5's
        # <extra_id_*> sentinels -- confirm this is intended.
        sentinel_ids = np.where(sentinel_ids != 0,
                                (len(self.tokenizer) - sentinel_ids), 0)
        # Mark the non-initial positions of each span with -1.
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """Collapse masked spans to their sentinel and append EOS.

        Adapted from
        https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py

        Args:
            input_ids: Integer tensor of shape ``(batch, seq)``.
            sentinel_ids: Array from :meth:`create_sentinel_ids`.

        Returns:
            An integer tensor on the same device as ``input_ids``. Every row
            must keep the same number of tokens for the reshape to succeed.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids,
                                  input_ids.cpu().numpy())

        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        kept_ids = input_ids_full[input_ids_full >= 0].reshape(
            (batch_size, -1))
        kept_ids = np.concatenate(
            [
                kept_ids,
                np.full((batch_size, 1),
                        self.tokenizer.eos_token_id,
                        dtype=np.int32),
            ],
            axis=-1,
        )

        return torch.tensor(kept_ids, device=device)