import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from ..builder import SUBMODULES
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp32"""
def _convert_weights_to_fp32(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.float()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.float()
model.apply(_convert_weights_to_fp32)
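# Usage sketch (not part of the original module): `convert_weights` undoes the
# fp16 cast applied in `load_and_freeze_clip` below so the CLIP text encoder
# can run on CPU. 'ViT-B/32' is only an assumed checkpoint name.
#
#   clip_model, _ = clip.load('ViT-B/32', device='cpu', jit=False)
#   clip.model.convert_weights(clip_model)  # fp16, as in load_and_freeze_clip
#   convert_weights(clip_model)             # back to fp32 before CPU inference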
@SUBMODULES.register_module()
class MDMTransformer(nn.Module):
    """Transformer denoiser for MDM-style text-to-motion diffusion with a frozen CLIP text encoder."""
def __init__(self,
input_feats=263,
latent_dim=256,
ff_size=1024,
num_layers=8,
num_heads=4,
dropout=0.1,
activation="gelu",
clip_dim=512,
clip_version=None,
guide_scale=1.0,
cond_mask_prob=0.1,
use_official_ckpt=False,
**kwargs):
super().__init__()
self.latent_dim = latent_dim
self.ff_size = ff_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.activation = activation
self.clip_dim = clip_dim
self.input_feats = input_feats
self.guide_scale = guide_scale
self.use_official_ckpt = use_official_ckpt
self.cond_mask_prob = cond_mask_prob
self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
nhead=self.num_heads,
dim_feedforward=self.ff_size,
dropout=self.dropout,
activation=self.activation)
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
num_layers=self.num_layers)
self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
self.clip_version = clip_version
self.clip_model = self.load_and_freeze_clip(clip_version)
self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(
            clip_version, device='cpu', jit=False)  # jit must be False for training
        # cast CLIP weights to fp16 to save memory; they are cast back to fp32
        # for CPU inference in get_precompute_condition
        clip.model.convert_weights(clip_model)
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
return clip_model
    def mask_cond(self, cond, force_mask=False):
        """Randomly drop the text condition (classifier-free guidance training)."""
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            # 1 -> use null condition, 0 -> use the real condition
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)
            return cond * (1. - mask)
else:
return cond
    def encode_text(self, raw_text):
        # raw_text - list (batch_size length) of strings with input text prompts
        device = next(self.parameters()).device
        # Tokenize with a short context (start token + 20 text tokens + end token)
        # and zero-pad back to CLIP's default context length of 77.
        max_text_len = 20
        default_context_length = 77
        context_length = max_text_len + 2  # start_token + 20 + end_token
        assert context_length < default_context_length
        texts = clip.tokenize(
            raw_text, context_length=context_length, truncate=True).to(device)
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype, device=texts.device)
        texts = torch.cat([texts, zero_pad], dim=1)
        return self.clip_model.encode_text(texts).float()
def get_precompute_condition(self, text, device=None, **kwargs):
        if not self.training and device == torch.device('cpu'):
            # cast the fp16 CLIP weights back to fp32 for CPU inference
            convert_weights(self.clip_model)
text_feat = self.encode_text(text)
return {'text_feat': text_feat}
def post_process(self, motion):
assert len(motion.shape) == 3
        if self.use_official_ckpt:
            # rescale the first 4 feature channels to match the official checkpoint's normalization
            motion[:, :, :4] = motion[:, :, :4] * 25
return motion
def forward(self, motion, timesteps, text_feat=None, **kwargs):
"""
motion: B, T, D
timesteps: [batch_size] (int)
"""
B, T, D = motion.shape
device = motion.device
if text_feat is None:
            enc_text = self.get_precompute_condition(**kwargs)['text_feat']
else:
enc_text = text_feat
if self.training:
# T, B, D
motion = self.poseEmbedding(motion).permute(1, 0, 2)
emb = self.embed_timestep(timesteps) # [1, bs, d]
emb += self.embed_text(self.mask_cond(enc_text, force_mask=False))
            xseq = self.sequence_pos_encoder(torch.cat((emb, motion), dim=0))
output = self.seqTransEncoder(xseq)[1:]
# B, T, D
output = self.poseFinal(output).permute(1, 0, 2)
return output
else:
# T, B, D
motion = self.poseEmbedding(motion).permute(1, 0, 2)
emb = self.embed_timestep(timesteps) # [1, bs, d]
            # build unconditioned and text-conditioned inputs for classifier-free guidance
            emb_uncond = emb + self.embed_text(self.mask_cond(enc_text, force_mask=True))
            emb_text = emb + self.embed_text(self.mask_cond(enc_text, force_mask=False))
            xseq = self.sequence_pos_encoder(torch.cat((emb_uncond, motion), dim=0))
            xseq_text = self.sequence_pos_encoder(torch.cat((emb_text, motion), dim=0))
output = self.seqTransEncoder(xseq)[1:]
output_text = self.seqTransEncoder(xseq_text)[1:]
# B, T, D
output = self.poseFinal(output).permute(1, 0, 2)
output_text = self.poseFinal(output_text).permute(1, 0, 2)
            # classifier-free guidance: uncond + scale * (cond - uncond)
            scale = self.guide_scale
            output = output + scale * (output_text - output)
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
        # x: [T, B, D]; add the sinusoidal encoding along the sequence (time) dimension
x = x + self.pe[:x.shape[0], :]
return self.dropout(x)
class TimestepEmbedder(nn.Module):
def __init__(self, latent_dim, sequence_pos_encoder):
super().__init__()
self.latent_dim = latent_dim
self.sequence_pos_encoder = sequence_pos_encoder
time_embed_dim = self.latent_dim
self.time_embed = nn.Sequential(
nn.Linear(self.latent_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
    def forward(self, timesteps):
        # look up the sinusoidal table at each timestep and project it -> [1, batch_size, latent_dim]
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
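# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a CLIP
# checkpoint name ('ViT-B/32') and random inputs; in the surrounding codebase
# the model is normally built from a config dict via the SUBMODULES registry.
# Defined as a function (not executed on import) because this file uses a
# relative import and is not meant to be run directly.
# ---------------------------------------------------------------------------
def _demo():
    model = MDMTransformer(
        input_feats=263,          # assumption: HumanML3D-style motion features
        latent_dim=256,
        clip_version='ViT-B/32',  # assumption: any name accepted by clip.load
        guide_scale=2.5)          # assumption: a typical guidance scale
    model.eval()
    motion = torch.randn(2, 60, 263)          # [batch, frames, feature dim]
    timesteps = torch.randint(0, 1000, (2,))  # one diffusion step index per sample
    cond = model.get_precompute_condition(
        text=['a person walks forward', 'a person jumps'],
        device=torch.device('cpu'))
    with torch.no_grad():
        out = model(motion, timesteps, text_feat=cond['text_feat'])
    print(out.shape)  # expected: torch.Size([2, 60, 263])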