import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import clip

from ..builder import SUBMODULES


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp32.

    Mirrors ``clip.model.convert_weights`` (which casts to fp16), covering the
    same parameter set: Conv/Linear layers, ``nn.MultiheadAttention``, and the
    bare projection matrices CLIP stores outside of submodules.
    """

    def _convert_weights_to_fp32(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                    *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                    "in_proj_bias", "bias_k", "bias_v"
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)
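

# Usage sketch (assumes the ViT-B/32 CLIP weights are available via
# `clip.load`): `clip.model.convert_weights` casts a CLIP model to fp16, and
# the helper above restores fp32 so the text encoder can run on CPU, where
# fp16 kernels are often missing.
#
#   clip_model, _ = clip.load('ViT-B/32', device='cpu', jit=False)
#   clip.model.convert_weights(clip_model)  # -> fp16, as used on GPU
#   convert_weights(clip_model)             # -> back to fp32 for CPU inference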


@SUBMODULES.register_module()
class MDMTransformer(nn.Module):
    """Transformer backbone of MDM (Human Motion Diffusion Model)."""

    def __init__(self,
                 input_feats=263,
                 latent_dim=256,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=4,
                 dropout=0.1,
                 activation="gelu",
                 clip_dim=512,
                 clip_version=None,
                 guide_scale=1.0,
                 cond_mask_prob=0.1,
                 use_official_ckpt=False,
                 **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.clip_dim = clip_dim
        self.input_feats = input_feats
        self.guide_scale = guide_scale
        self.use_official_ckpt = use_official_ckpt
        self.cond_mask_prob = cond_mask_prob

        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim,
                                                       self.dropout)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=self.num_layers)
        self.embed_timestep = TimestepEmbedder(self.latent_dim,
                                               self.sequence_pos_encoder)
        self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
        self.clip_version = clip_version
        self.clip_model = self.load_and_freeze_clip(clip_version)
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)

    def load_and_freeze_clip(self, clip_version):
        # jit must be False so the weights stay accessible as an nn.Module
        clip_model, clip_preprocess = clip.load(clip_version,
                                                device='cpu',
                                                jit=False)
        # clip.load(device='cpu') returns fp32 weights; cast them back to fp16,
        # the dtype the released CLIP checkpoints use on GPU. This is undone
        # again for CPU inference in get_precompute_condition.
        clip.model.convert_weights(clip_model)

        # Freeze the CLIP text encoder
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            # 1 -> replace with the null condition, 0 -> keep the real one
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) *
                self.cond_mask_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def encode_text(self, raw_text):
        # raw_text: list of strings (batch_size,) with the input text prompts
        device = next(self.parameters()).device
        max_text_len = 20
        default_context_length = 77
        context_length = max_text_len + 2  # start_token + 20 tokens + end_token
        assert context_length < default_context_length
        texts = clip.tokenize(raw_text,
                              context_length=context_length,
                              truncate=True).to(device)
        # Zero-pad back to CLIP's fixed context length of 77 so the frozen
        # positional embedding still lines up.
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype,
            device=texts.device)
        texts = torch.cat([texts, zero_pad], dim=1)
        return self.clip_model.encode_text(texts).float()
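
    # Example (a sketch of the padding above): for a single prompt,
    #   tokens = clip.tokenize(['a person walks'], context_length=22,
    #                          truncate=True)     # shape (1, 22)
    #   tokens = F.pad(tokens, (0, 77 - 22))      # shape (1, 77), zero-padded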

    def get_precompute_condition(self, text, device=None, **kwargs):
        if not self.training and device == torch.device('cpu'):
            # fp16 CLIP weights (see load_and_freeze_clip) are not generally
            # supported on CPU; restore fp32 before encoding
            convert_weights(self.clip_model)
        text_feat = self.encode_text(text)
        return {'text_feat': text_feat}

    def post_process(self, motion):
        assert len(motion.shape) == 3
        if self.use_official_ckpt:
            # the official MDM checkpoint stores the first 4 (root) features
            # scaled down by a factor of 25; undo that here
            motion[:, :, :4] = motion[:, :, :4] * 25
        return motion

    def forward(self, motion, timesteps, text_feat=None, **kwargs):
        """
        motion: (B, T, D) noisy motion
        timesteps: (B,) int diffusion steps
        """
        B, T, D = motion.shape
        device = motion.device
        if text_feat is None:
            enc_text = self.get_precompute_condition(**kwargs)['text_feat']
        else:
            enc_text = text_feat

        if self.training:
            # (B, T, D) -> (T, B, latent_dim)
            motion = self.poseEmbedding(motion).permute(1, 0, 2)
            emb = self.embed_timestep(timesteps)  # (1, B, latent_dim)
            # randomly drop the text condition (classifier-free training)
            emb += self.embed_text(self.mask_cond(enc_text, force_mask=False))
            xseq = self.sequence_pos_encoder(torch.cat((emb, motion), dim=0))
            # drop the embedding token, keep the T motion tokens
            output = self.seqTransEncoder(xseq)[1:]
            # (T, B, latent_dim) -> (B, T, D)
            output = self.poseFinal(output).permute(1, 0, 2)
            return output
        else:
            # (B, T, D) -> (T, B, latent_dim)
            motion = self.poseEmbedding(motion).permute(1, 0, 2)
            emb = self.embed_timestep(timesteps)  # (1, B, latent_dim)
            # two passes: text condition masked out vs. active
            emb_uncond = emb + self.embed_text(
                self.mask_cond(enc_text, force_mask=True))
            emb_text = emb + self.embed_text(
                self.mask_cond(enc_text, force_mask=False))
            xseq = self.sequence_pos_encoder(
                torch.cat((emb_uncond, motion), dim=0))
            xseq_text = self.sequence_pos_encoder(
                torch.cat((emb_text, motion), dim=0))
            output = self.seqTransEncoder(xseq)[1:]
            output_text = self.seqTransEncoder(xseq_text)[1:]
            # (T, B, latent_dim) -> (B, T, D)
            output = self.poseFinal(output).permute(1, 0, 2)
            output_text = self.poseFinal(output_text).permute(1, 0, 2)
            # classifier-free guidance
            scale = self.guide_scale
            output = output + scale * (output_text - output)
            return output
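

# Classifier-free guidance, as applied in MDMTransformer.forward at inference
# time: with s = guide_scale, the two transformer passes are combined as
#     out = out_uncond + s * (out_text - out_uncond),
# so s = 0 gives the unconditional prediction, s = 1 the purely
# text-conditional one, and s > 1 extrapolates away from the unconditional
# output to strengthen text adherence.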


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # (max_len, 1, d_model): broadcasts over batch for (T, B, d) inputs
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
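

# Sinusoidal scheme (for reference): with pos the sequence index and i the
# feature-pair index,
#     pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#     pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# The same buffer doubles as the timestep lookup table in TimestepEmbedder
# below, which indexes pe by diffusion step instead of adding it to the input.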


class TimestepEmbedder(nn.Module):

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # pe[timesteps]: (B, 1, d) -> MLP -> permute to (1, B, d)
        return self.time_embed(
            self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
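

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): assumes
    # the 'ViT-B/32' CLIP checkpoint is downloadable and uses the HumanML3D
    # default of 263 motion features. Because of the relative import at the
    # top, run this with `python -m <package path to this module>` rather than
    # directly.
    model = MDMTransformer(clip_version='ViT-B/32').eval()
    motion = torch.randn(2, 60, 263)           # (B, T, D)
    timesteps = torch.randint(0, 1000, (2, ))  # one diffusion step per sample
    with torch.no_grad():
        cond = model.get_precompute_condition(
            text=['a person walks forward', 'a person jumps'],
            device=motion.device)
        out = model(motion, timesteps, **cond)
    print(out.shape)  # torch.Size([2, 60, 263])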