import logging

import torch
import torch.utils.checkpoint
from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ..modules.unet import UNetSpatioTemporalConditionModel
from ..modules.pose_net import PoseNet
from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline

logger = logging.getLogger(__name__)


class MimicMotionModel(torch.nn.Module):
    """Container module grouping the sub-models used by the MimicMotion pipeline.

    Holding every component on a single ``torch.nn.Module`` lets one
    ``load_state_dict`` call populate all of them from a single checkpoint.
    """

    def __init__(self, base_model_path: str) -> None:
        """Construct the base model components from a pretrained SVD model directory.

        The VAE and the image encoder are loaded with ``from_pretrained`` using
        the fp16 variant; the scheduler and feature extractor are loaded with
        ``from_pretrained`` as well. The UNet is only instantiated from its
        config (``from_config``) and the pose-net is freshly constructed, so
        neither of them is loaded via ``from_pretrained`` here.

        Args:
            base_model_path (str): pretrained svd model path, expected to
                contain the ``unet``, ``vae``, ``image_encoder``,
                ``scheduler`` and ``feature_extractor`` subfolders.
        """
        super().__init__()

        # UNet: architecture built from the stored config only.
        unet_config = UNetSpatioTemporalConditionModel.load_config(
            base_model_path, subfolder="unet"
        )
        self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config)

        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
            base_model_path,
            subfolder="vae",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            base_model_path,
            subfolder="image_encoder",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            base_model_path, subfolder="scheduler"
        )
        self.feature_extractor = CLIPImageProcessor.from_pretrained(
            base_model_path, subfolder="feature_extractor"
        )

        # pose_net: its latent channel count is taken from the first UNet
        # block's output channels.
        self.pose_net = PoseNet(
            noise_latent_channels=self.unet.config.block_out_channels[0]
        )


def create_pipeline(infer_config, device):
    """Create the MimicMotion pipeline and load the pretrained checkpoint.

    Args:
        infer_config: configuration object exposing ``base_model_path``
            (pretrained svd model path) and ``ckpt_path`` (checkpoint file
            passed to ``torch.load``) attributes.
        device (str or torch.device): "cpu" or "cuda:{device_id}".
            NOTE(review): this argument is currently not read by the function;
            the checkpoint is always mapped to CPU. Presumably the caller moves
            the pipeline to the target device afterwards -- confirm.

    Returns:
        MimicMotionPipeline: pipeline assembled from the loaded components.
    """
    model = MimicMotionModel(infer_config.base_model_path)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's ``weights_only`` default; only load trusted checkpoints.
    state_dict = torch.load(infer_config.ckpt_path, map_location="cpu")

    # strict=False tolerates key mismatches; report them instead of dropping
    # the result silently so a wrong checkpoint is noticeable in the logs.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.unexpected_keys:
        logger.warning(
            "checkpoint %s has %d unexpected keys that were ignored: %s",
            infer_config.ckpt_path,
            len(load_result.unexpected_keys),
            load_result.unexpected_keys,
        )
    if load_result.missing_keys:
        logger.info(
            "checkpoint %s does not provide %d model keys",
            infer_config.ckpt_path,
            len(load_result.missing_keys),
        )

    pipeline = MimicMotionPipeline(
        vae=model.vae,
        image_encoder=model.image_encoder,
        unet=model.unet,
        scheduler=model.noise_scheduler,
        feature_extractor=model.feature_extractor,
        pose_net=model.pose_net,
    )
    return pipeline