# MotionGPT/mGPT/archs/mgpt_lm.py
import math
import random
from typing import List, Optional

import numpy as np
import torch
from torch import Tensor, nn
from transformers import (AutoTokenizer, GPT2LMHeadModel,
                          T5ForConditionalGeneration)

from .tools.token_emb import NewTokenEmb

class MLM(nn.Module):
def __init__(
self,
model_path: str,
model_type: str = "t5",
stage: str = "lm_pretrain",
new_token_type: str = "insert",
motion_codebook_size: int = 512,
framerate: float = 20.0,
down_t: int = 4,
predict_ratio: float = 0.2,
inbetween_ratio: float = 0.25,
max_length: int = 256,
lora: bool = False,
quota_ratio: float = 0.5,
noise_density: float = 0.15,
mean_noise_span_length: int = 3,
**kwargs,
) -> None:
super().__init__()
# Parameters
self.m_codebook_size = motion_codebook_size
self.max_length = max_length
self.framerate = framerate
self.down_t = down_t
self.predict_ratio = predict_ratio
self.inbetween_ratio = inbetween_ratio
self.noise_density = noise_density
self.mean_noise_span_length = mean_noise_span_length
self.quota_ratio = quota_ratio
self.stage = stage
# Instantiate language model
self.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=True)
if model_type == "t5":
self.language_model = T5ForConditionalGeneration.from_pretrained(
model_path)
self.lm_type = 'encdec'
elif model_type == "gpt2":
self.language_model = GPT2LMHeadModel.from_pretrained(model_path)
self.lm_type = 'dec'
        else:
            raise ValueError("model_type must be either 't5' or 'gpt2'")
if self.lm_type == 'dec':
self.tokenizer.pad_token = self.tokenizer.eos_token
# Add motion tokens
self.tokenizer.add_tokens(
[f'<motion_id_{i}>' for i in range(self.m_codebook_size + 3)])
if new_token_type == "insert":
self.language_model.resize_token_embeddings(len(self.tokenizer))
elif new_token_type == "mlp":
shared = NewTokenEmb(self.language_model.shared,
self.m_codebook_size + 3)
# lm_head = NewTokenEmb(self.language_model.lm_head,
# self.m_codebook_size + 3)
self.language_model.resize_token_embeddings(len(self.tokenizer))
self.language_model.shared = shared
# self.language_model.lm_head = lm_head
# Lora
if lora:
            from peft import LoraConfig, get_peft_model
peft_config = LoraConfig(
bias="none",
task_type="CAUSAL_LM",
# inference_mode=False,
r=8,
lora_alpha=16,
lora_dropout=0.05)
self.language_model = get_peft_model(self.language_model,
peft_config)
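
    # A minimal usage sketch (hypothetical checkpoint path), not part of the
    # original file:
    #
    #   >>> lm = MLM(model_path="t5-base", model_type="t5",
    #   ...          motion_codebook_size=512)
    #
    # The tokenizer then gains 512 + 3 extra tokens: <motion_id_0..511> for
    # the codebook entries, <motion_id_512> (start of motion),
    # <motion_id_513> (end of motion) and <motion_id_514> (mask).
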
def forward(self, texts: List[str], motion_tokens: Tensor,
lengths: List[int], tasks: dict):
if self.lm_type == 'encdec':
return self.forward_encdec(texts, motion_tokens, lengths, tasks)
elif self.lm_type == 'dec':
return self.forward_dec(texts, motion_tokens, lengths, tasks)
        else:
            raise NotImplementedError(
                f"Unsupported lm_type: {self.lm_type}")
def forward_encdec(
self,
texts: List[str],
motion_tokens: Tensor,
lengths: List[int],
tasks: dict,
):
# Tensor to string
motion_strings = self.motion_token_to_string(motion_tokens, lengths)
        # Condition sampling: the unsupervised 'text'/'motion' denoising
        # branches are currently disabled, so training is always supervised
        # condition = random.choice(
        #     ['text', 'motion', 'supervised', 'supervised', 'supervised'])
        condition = random.choice(['supervised', 'supervised', 'supervised'])
if condition == 'text':
inputs = texts
outputs = texts
elif condition == 'motion':
inputs = motion_strings
outputs = motion_strings
else:
inputs, outputs = self.template_fulfill(tasks, lengths,
motion_strings, texts)
# Tokenize
source_encoding = self.tokenizer(inputs,
padding='max_length',
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt")
source_attention_mask = source_encoding.attention_mask.to(
motion_tokens.device)
source_input_ids = source_encoding.input_ids.to(motion_tokens.device)
if condition in ['text', 'motion']:
            batch_size, expanded_input_length = source_input_ids.shape
            mask_indices = np.asarray([
                self.random_spans_noise_mask(expanded_input_length)
                for i in range(batch_size)
            ])
target_mask = ~mask_indices
input_ids_sentinel = self.create_sentinel_ids(
mask_indices.astype(np.int8))
target_sentinel = self.create_sentinel_ids(
target_mask.astype(np.int8))
labels_input_ids = self.filter_input_ids(source_input_ids,
target_sentinel)
source_input_ids = self.filter_input_ids(source_input_ids,
input_ids_sentinel)
else:
target_inputs = self.tokenizer(outputs,
padding='max_length',
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt")
labels_input_ids = target_inputs.input_ids.to(motion_tokens.device)
            labels_attention_mask = target_inputs.attention_mask.to(
                motion_tokens.device)

        # Pad positions (pad id is 0 for T5) become -100 so the loss ignores them
        labels_input_ids[labels_input_ids == 0] = -100
outputs = self.language_model(
input_ids=source_input_ids,
attention_mask=source_attention_mask
if condition == 'supervised' else None,
labels=labels_input_ids,
            decoder_attention_mask=labels_attention_mask
            if condition == 'supervised' else None,
)
return outputs
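
    # Sketch of what the supervised branch feeds the T5 model for a single
    # t2m sample (illustrative strings, not taken from the file):
    #
    #   input : "Generate motion: a person waves."
    #   output: "<motion_id_512><motion_id_17>...<motion_id_513>"
    #
    # In the (disabled) unsupervised branches, the same sequence would instead
    # be split by random_spans_noise_mask into sentinel-separated source and
    # target halves, as in T5 span corruption.
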
def forward_dec(
self,
texts: List[str],
motion_tokens: Tensor,
lengths: List[int],
tasks: dict,
):
self.tokenizer.padding_side = "right"
# Tensor to string
motion_strings = self.motion_token_to_string(motion_tokens, lengths)
# Supervised or unsupervised
condition = random.choice(
['text', 'motion', 'supervised', 'supervised', 'supervised'])
if condition == 'text':
labels = texts
elif condition == 'motion':
labels = motion_strings
else:
inputs, outputs = self.template_fulfill(tasks, lengths,
motion_strings, texts)
labels = []
for i in range(len(inputs)):
labels.append(inputs[i] + ' \n ' + outputs[i] +
self.tokenizer.eos_token)
# Tokenize
inputs = self.tokenizer(labels,
padding='max_length',
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt")
        labels_input_ids = inputs.input_ids.to(motion_tokens.device)
        labels_attention_mask = inputs.attention_mask.to(motion_tokens.device)
        outputs = self.language_model(input_ids=labels_input_ids,
                                      attention_mask=labels_attention_mask,
                                      labels=labels_input_ids)
return outputs
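
    # For the decoder-only (GPT-2) path, each sample is one causal string,
    # e.g. (illustrative):
    #
    #   "Generate motion: a person waves. \n <motion_id_512>...<motion_id_513><|endoftext|>"
    #
    # and the labels are the input ids themselves, so the model learns to
    # continue the prompt with the motion tokens.
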
def generate_direct(self,
texts: List[str],
max_length: int = 256,
num_beams: int = 1,
do_sample: bool = True,
                        bad_words_ids: Optional[List[List[int]]] = None):
# Device
self.device = self.language_model.device
# Tokenize
        if self.lm_type == 'dec':
            # Left-pad decoder-only prompts so generation continues from the
            # end of each prompt in the batch
            self.tokenizer.padding_side = 'left'
            texts = [text + " \n " for text in texts]
source_encoding = self.tokenizer(texts,
padding='max_length',
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt")
source_input_ids = source_encoding.input_ids.to(self.device)
source_attention_mask = source_encoding.attention_mask.to(self.device)
if self.lm_type == 'encdec':
outputs = self.language_model.generate(
source_input_ids,
max_length=max_length,
num_beams=num_beams,
do_sample=do_sample,
bad_words_ids=bad_words_ids,
)
elif self.lm_type == 'dec':
outputs = self.language_model.generate(
input_ids=source_input_ids,
attention_mask=source_attention_mask,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=do_sample,
max_new_tokens=max_length)
outputs_string = self.tokenizer.batch_decode(outputs,
skip_special_tokens=True)
outputs_tokens, cleaned_text = self.motion_string_to_token(
outputs_string)
return outputs_tokens, cleaned_text
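
    # Hypothetical call, assuming the language model is already on its
    # target device:
    #
    #   >>> tokens, text = lm.generate_direct(
    #   ...     ["Generate motion: a person jumps."], max_length=128)
    #
    # tokens is a list of LongTensors of motion code indices; text holds the
    # decoded strings with each motion span collapsed to '<Motion_Placeholder>'.
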
def generate_conditional(self,
texts: Optional[List[str]] = None,
motion_tokens: Optional[Tensor] = None,
lengths: Optional[List[int]] = None,
task: str = "t2m",
with_len: bool = False,
stage: str = 'train',
tasks: dict = None):
self.device = self.language_model.device
if task in ["t2m", "m2m", "pred", "inbetween"]:
if task == "t2m":
assert texts is not None
motion_strings = [''] * len(texts)
if not with_len:
if tasks is None:
tasks = [{
'input':
['Generate motion: <Caption_Placeholder>'],
'output': ['']
}] * len(texts)
lengths = [0] * len(texts)
else:
tasks = [{
'input': [
'Generate motion with <Frame_Placeholder> frames: <Caption_Placeholder>'
],
'output': ['']
}] * len(texts)
elif task == "pred":
assert motion_tokens is not None and lengths is not None
texts = [''] * len(lengths)
tasks = [{
'input': ['Predict motion: <Motion_Placeholder_s1>'],
'output': ['']
}] * len(lengths)
motion_strings_old = self.motion_token_to_string(
motion_tokens, lengths)
motion_strings = []
for i, length in enumerate(lengths):
split = length // 5
motion_strings.append(
'>'.join(motion_strings_old[i].split('>')[:split]) +
'>')
elif task == "inbetween":
assert motion_tokens is not None and lengths is not None
texts = [''] * len(lengths)
tasks = [{
'input': [
"Complete the masked motion: <Motion_Placeholder_Masked>"
],
'output': ['']
}] * len(lengths)
motion_strings = self.motion_token_to_string(
motion_tokens, lengths)
inputs, outputs = self.template_fulfill(tasks, lengths,
motion_strings, texts,
stage)
outputs_tokens, cleaned_text = self.generate_direct(inputs,
max_length=128,
num_beams=1,
do_sample=True)
return outputs_tokens
elif task == "m2t":
assert motion_tokens is not None and lengths is not None
motion_strings = self.motion_token_to_string(
motion_tokens, lengths)
if not with_len:
tasks = [{
'input': ['Generate text: <Motion_Placeholder>'],
'output': ['']
}] * len(lengths)
else:
tasks = [{
'input': [
'Generate text with <Frame_Placeholder> frames: <Motion_Placeholder>'
],
'output': ['']
}] * len(lengths)
texts = [''] * len(lengths)
inputs, outputs = self.template_fulfill(tasks, lengths,
motion_strings, texts)
outputs_tokens, cleaned_text = self.generate_direct(
inputs,
max_length=40,
num_beams=1,
do_sample=False,
# bad_words_ids=self.bad_words_ids
)
            return cleaned_text
        else:
            raise ValueError(f"Unsupported task: {task}")
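
    # Hypothetical text-to-motion call (t2m is the default task):
    #
    #   >>> motion_tokens = lm.generate_conditional(
    #   ...     texts=["a person walks in a circle"], task="t2m")
    #
    # m2t returns the cleaned caption strings instead of motion tokens.
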
    def motion_token_to_string(self, motion_token: Tensor, lengths: List[int]):
        motion_string = []
        for i in range(len(motion_token)):
            # .cpu() is a no-op for tensors already on the CPU
            motion_i = motion_token[i].cpu()
            motion_list = motion_i.tolist()[:lengths[i]]
            motion_string.append(
                f'<motion_id_{self.m_codebook_size}>' +
                ''.join([f'<motion_id_{int(j)}>' for j in motion_list]) +
                f'<motion_id_{self.m_codebook_size + 1}>')
        return motion_string
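
    # Example of the token-to-string mapping with m_codebook_size = 512:
    #
    #   tensor([5, 9]) -> '<motion_id_512><motion_id_5><motion_id_9><motion_id_513>'
    #
    # where ids 512/513 act as start/end-of-motion markers.
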
    def motion_token_list_to_string(self, motion_token: Tensor):
        motion_string = []
        for i in range(len(motion_token)):
            motion_i = motion_token[i].cpu()
            motion_list = motion_i.tolist()
            motion_string.append(
                f'<motion_id_{self.m_codebook_size}>' +
                ''.join([f'<motion_id_{int(j)}>' for j in motion_list]) +
                f'<motion_id_{self.m_codebook_size + 1}>')
        return motion_string
def motion_string_to_token(self, motion_string: List[str]):
motion_tokens = []
output_string = []
for i in range(len(motion_string)):
string = self.get_middle_str(
motion_string[i], f'<motion_id_{self.m_codebook_size}>',
f'<motion_id_{self.m_codebook_size + 1}>')
            string_list = string.split('><')
            token_list = [
                int(token.split('_')[-1].replace('>', ''))
                for token in string_list[1:-1]
            ]
if len(token_list) == 0:
token_list = [0]
token_list_padded = torch.tensor(token_list,
dtype=int).to(self.device)
motion_tokens.append(token_list_padded)
output_string.append(motion_string[i].replace(
string, '<Motion_Placeholder>'))
return motion_tokens, output_string
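
    # Inverse of motion_token_to_string, e.g.:
    #
    #   '<motion_id_512><motion_id_5><motion_id_9><motion_id_513>'
    #       -> tensor([5, 9]), with the motion span in the surrounding text
    #          replaced by '<Motion_Placeholder>'
    #
    # Note this relies on self.device having been set by generate_direct.
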
def placeholder_fulfill(self, prompt: str, length: int, motion_string: str,
text: str):
seconds = math.floor(length / self.framerate)
motion_splited = motion_string.split('>')
token_length = length / self.down_t
predict_head = int(token_length * self.predict_ratio + 1)
masked_head = int(token_length * self.inbetween_ratio + 1)
masked_tail = int(token_length * (1 - self.inbetween_ratio) + 1)
motion_predict_head = '>'.join(
motion_splited[:predict_head]
) + f'><motion_id_{self.m_codebook_size+1}>'
motion_predict_last = f'<motion_id_{self.m_codebook_size}>' + '>'.join(
motion_splited[predict_head:])
motion_masked = '>'.join(
motion_splited[:masked_head]
) + '>' + f'<motion_id_{self.m_codebook_size+2}>' * (
masked_tail - masked_head) + '>'.join(motion_splited[masked_tail:])
if random.random() < self.quota_ratio:
text = f'\"{text}\"'
prompt = prompt.replace('<Caption_Placeholder>', text).replace(
'<Motion_Placeholder>',
motion_string).replace('<Frame_Placeholder>', f'{length}').replace(
'<Second_Placeholder>', '%.1f' % seconds).replace(
'<Motion_Placeholder_s1>', motion_predict_head).replace(
'<Motion_Placeholder_s2>',
motion_predict_last).replace(
'<Motion_Placeholder_Masked>', motion_masked)
return prompt
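
    # Example (hypothetical numbers): with framerate=20, down_t=4 and
    # length=40 frames, '<Second_Placeholder>' becomes '2.0' and the motion
    # string holds roughly 10 tokens. '<Motion_Placeholder_s1>' keeps the
    # first predict_ratio fraction of them, '<Motion_Placeholder_s2>' the
    # remainder, and '<Motion_Placeholder_Masked>' replaces the middle
    # (1 - 2 * inbetween_ratio) fraction with <motion_id_514> mask tokens.
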
def template_fulfill(self,
tasks,
lengths,
motion_strings,
texts,
stage='test'):
inputs = []
outputs = []
for i in range(len(lengths)):
input_template = random.choice(tasks[i]['input'])
output_template = random.choice(tasks[i]['output'])
length = lengths[i]
inputs.append(
self.placeholder_fulfill(input_template, length,
motion_strings[i], texts[i]))
outputs.append(
self.placeholder_fulfill(output_template, length,
motion_strings[i], texts[i]))
return inputs, outputs
    def get_middle_str(self, content, start_str, end_str):
        # Extract the substring between the start/end motion markers; fall
        # back to a minimal one-token motion string if either marker is missing
        try:
            start_index = content.index(start_str) + len(start_str)
            end_index = content.index(end_str)
        except ValueError:
            return (f'<motion_id_{self.m_codebook_size}><motion_id_0>'
                    f'<motion_id_{self.m_codebook_size + 1}>')
        return f'<motion_id_{self.m_codebook_size}>' + content[
            start_index:end_index] + f'<motion_id_{self.m_codebook_size + 1}>'
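
    # Example:
    #
    #   get_middle_str('ab<motion_id_512>X<motion_id_513>cd',
    #                  '<motion_id_512>', '<motion_id_513>')
    #
    # returns '<motion_id_512>X<motion_id_513>'; if either marker is absent,
    # the one-token fallback motion string is returned instead.
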
def random_spans_noise_mask(self, length):
# From https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py
orig_length = length
num_noise_tokens = int(np.round(length * self.noise_density))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
num_noise_spans = int(
np.round(num_noise_tokens / self.mean_noise_span_length))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def _random_segmentation(num_items, num_segments):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segments: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segments] containing positive integers that add
up to num_items
"""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]])
segment_id = np.cumsum(first_in_segment)
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens,
num_noise_spans)
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens,
num_noise_spans)
interleaved_span_lengths = np.reshape(
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2],
)
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = np.zeros((length, ), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
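
    # Example: with noise_density=0.15 and mean_noise_span_length=3, a
    # length-40 sequence gets round(40 * 0.15) = 6 noise tokens split into
    # round(6 / 3) = 2 spans, interleaved with non-noise spans starting from
    # a non-noise span, e.g. (one possible draw):
    #
    #   20 x False, 3 x True, 14 x False, 3 x True
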
def create_sentinel_ids(self, mask_indices):
# From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
start_indices = mask_indices - np.roll(mask_indices, 1,
axis=-1) * mask_indices
start_indices[:, 0] = mask_indices[:, 0]
sentinel_ids = np.where(start_indices != 0,
np.cumsum(start_indices, axis=-1),
start_indices)
sentinel_ids = np.where(sentinel_ids != 0,
(len(self.tokenizer) - sentinel_ids), 0)
sentinel_ids -= mask_indices - start_indices
return sentinel_ids
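
    # Example: for one row with mask [0 1 1 0 1 0] and vocabulary size V,
    # this yields sentinel ids [0, V-1, -1, 0, V-2, 0]: the first position of
    # each masked span gets a sentinel token id counting down from the end of
    # the vocabulary, and later positions in a span get -1 (dropped by
    # filter_input_ids).
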
    def filter_input_ids(self, input_ids, sentinel_ids):
        # From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
        batch_size = input_ids.shape[0]
        device = input_ids.device
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids,
                                  input_ids.to('cpu'))
        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        input_ids = input_ids_full[input_ids_full >= 0].reshape(
            (batch_size, -1))
        input_ids = np.concatenate(
            [
                input_ids,
                np.full((batch_size, 1),
                        self.tokenizer.eos_token_id,
                        dtype=np.int32),
            ],
            axis=-1,
        )
        # Use the device of the incoming batch rather than self.device, which
        # is only set inside the generate_* methods
        return torch.tensor(input_ids, device=device)
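
    # Continuing the create_sentinel_ids example: with source ids
    # [a b c d e f] and sentinel ids [0, V-1, -1, 0, V-2, 0], the filtered
    # sequence is [a, V-1, d, V-2, f, </s>]: each masked span is collapsed to
    # its sentinel token and an EOS id is appended, matching T5 span
    # corruption.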