# NOTE: header residue from the web file viewer this listing was copied from
# (not part of the program): Spaces status "Paused"; file size: 9,155 bytes;
# revision ef16dc7; the original listing was numbered 1-186.
import argparse
import datetime
import random
import os
import logging
from omegaconf import OmegaConf
import torch
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
from consisti2v.pipelines.pipeline_autoregress_animation import AutoregressiveAnimationPipeline
from consisti2v.utils.util import save_videos_grid
from diffusers.utils.import_utils import is_xformers_available
def main(args, config):
    """Sample videos with the autoregressive animation pipeline and save them.

    Builds an ``AutoregressiveAnimationPipeline`` (from individual pretrained
    components, or from a saved pipeline when ``config.pipeline_pretrained_path``
    is set), moves it to CUDA, runs it once per prompt, and writes the outputs
    under ``{config.output_dir}/{config.output_name}-{timestamp}``:

    * ``sample/{idx}-{prompt}.{format}`` for every prompt (one file per video
      when the pipeline returns more than one video for a prompt),
    * ``sample.{format}`` with all samples combined in one grid,
    * ``config.yaml`` with the effective config, including the seeds used,
    * ``model/`` with the saved pipeline when ``args.save_model`` is set.

    Args:
        args: Parsed command-line namespace. Reads ``prompt``, ``n_prompt``,
            ``path_to_first_frame``, ``seed``, ``prompt_config``, ``format``
            and ``save_model``.
        config: OmegaConf inference config. Reads ``output_dir``,
            ``output_name``, ``pipeline_pretrained_path``,
            ``pretrained_model_path``, ``unet_path``, ``unet_ckpt_prefix``,
            ``noise_scheduler_kwargs``, ``unet_additional_kwargs``,
            ``sampling_kwargs`` and ``frameinit_kwargs``. It is mutated:
            ``config.prompt_kwargs`` is set to the prompts, negative prompts,
            first-frame paths and the seeds actually used.

    Raises:
        AssertionError: If a loaded UNet checkpoint contains keys the UNet
            does not expect.
        FileExistsError: If the timestamped output directory already exists.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    diffusers.utils.logging.set_verbosity_info()

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"{config.output_dir}/{config.output_name}-{time_str}"
    os.makedirs(savedir)

    samples = []
    sample_idx = 0

    ### >>> create validation pipeline >>> ###
    if config.pipeline_pretrained_path is None:
        noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))
        tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
        text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True)
        unet = VideoLDMUNet3DConditionModel.from_pretrained(
            config.pretrained_model_path,
            subfolder="unet",
            variant=config.unet_additional_kwargs['variant'],
            temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'],
            augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'],
            use_temporal=True,
            n_frames=config.sampling_kwargs['n_frames'],
            n_temp_heads=config.unet_additional_kwargs['n_temp_heads'],
            first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'],
            use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'],
            use_safetensors=True,
        )

        # NOTE: despite the "UNet" label, the printed figure is the combined
        # parameter count of the UNet, the VAE and the text encoder.
        total_params = sum(
            p.numel()
            for module in (unet, vae, text_encoder)
            for p in module.parameters()
        )
        print(f"### UNet Parameters: {total_params / 1e6} M")

        # 1. unet ckpt
        if config.unet_path is not None:
            if os.path.isdir(config.unet_path):
                # Directory: load it as a pretrained model and copy its weights over.
                unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path)
                missing_keys, unexpected_keys = unet.load_state_dict(unet_dict.state_dict(), strict=False)
                assert len(unexpected_keys) == 0
                del unet_dict
            else:
                # Single checkpoint file.
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints from sources you trust.
                checkpoint_dict = torch.load(config.unet_path, map_location="cpu")
                state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict
                if config.unet_ckpt_prefix is not None:
                    # str.replace removes every occurrence of the prefix text
                    # in a key, not only a leading one.
                    state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()}
                missing_keys, unexpected_keys = unet.load_state_dict(state_dict, strict=False)
                assert len(unexpected_keys) == 0

        # xformers attention is only enabled on torch major versions below 2.
        if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
            unet.enable_xformers_memory_efficient_attention()

        pipeline = AutoregressiveAnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler)
    else:
        pipeline = AutoregressiveAnimationPipeline.from_pretrained(config.pipeline_pretrained_path)

    pipeline.to("cuda")

    # (frameinit) initialize frequency filter for noise reinitialization -------------
    if config.frameinit_kwargs.enable:
        pipeline.init_filter(
            width=config.sampling_kwargs.width,
            height=config.sampling_kwargs.height,
            video_length=config.sampling_kwargs.n_frames,
            filter_params=config.frameinit_kwargs.filter_params,
        )
    # -------------------------------------------------------------------------------
    ### <<< create validation pipeline <<< ###

    if args.prompt is not None:
        # A prompt given on the command line overrides the prompt config file.
        prompts = [args.prompt]
        n_prompts = [args.n_prompt]
        first_frame_paths = [args.path_to_first_frame]
        random_seeds = [int(args.seed)] if args.seed != "random" else "random"
    else:
        prompt_config = OmegaConf.load(args.prompt_config)
        prompts = prompt_config.prompts
        # A single negative prompt is shared by every prompt.
        n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts
        first_frame_paths = prompt_config.path_to_first_frames
        random_seeds = prompt_config.seeds

    if random_seeds == "random":
        # Integer bound: random.randint rejects float arguments such as 1e5
        # on Python 3.12+ (TypeError) and deprecated them from 3.10.
        random_seeds = [random.randint(0, 100_000) for _ in range(len(prompts))]
    else:
        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
        # A single seed is shared by every prompt.
        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

    # Recorded in the saved config.yaml; random_seeds is filled in as sampling proceeds.
    config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths})

    # zip stops at the shortest list, so mismatched list lengths drop trailing prompts.
    for prompt, n_prompt, first_frame_path, random_seed in zip(prompts, n_prompts, first_frame_paths, random_seeds):
        # manually set random seed for reproduction; -1 requests a fresh seed from torch
        if random_seed != -1:
            torch.manual_seed(random_seed)
        else:
            torch.seed()
        config.prompt_kwargs.random_seeds.append(torch.initial_seed())

        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")
        sample = pipeline(
            prompt,
            negative_prompt=n_prompt,
            first_frame_paths=first_frame_path,
            num_inference_steps=config.sampling_kwargs.steps,
            guidance_scale_txt=config.sampling_kwargs.guidance_scale_txt,
            guidance_scale_img=config.sampling_kwargs.guidance_scale_img,
            width=config.sampling_kwargs.width,
            height=config.sampling_kwargs.height,
            video_length=config.sampling_kwargs.n_frames,
            noise_sampling_method=config.unet_additional_kwargs['noise_sampling_method'],
            noise_alpha=float(config.unet_additional_kwargs['noise_alpha']),
            eta=config.sampling_kwargs.ddim_eta,
            frame_stride=config.sampling_kwargs.frame_stride,
            guidance_rescale=config.sampling_kwargs.guidance_rescale,
            num_videos_per_prompt=config.sampling_kwargs.num_videos_per_prompt,
            autoregress_steps=config.sampling_kwargs.autoregress_steps,
            use_frameinit=config.frameinit_kwargs.enable,
            frameinit_noise_level=config.frameinit_kwargs.noise_level,
        ).videos
        samples.append(sample)

        # File-name slug: first ten space-separated words joined by "-",
        # with "/" and ":" removed.
        prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "")
        if sample.shape[0] > 1:
            # Several videos for this prompt: save each one separately.
            for cnt, samp in enumerate(sample):
                save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format)
        else:
            save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format)
        # NOTE: the printed path omits the "{sample_idx}-" prefix used in the
        # actual file names above.
        print(f"save to {savedir}/sample/{prompt}.{args.format}")
        sample_idx += 1

    # Combined grid of every sample, plus the effective config for reproduction.
    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format)
    OmegaConf.save(config, f"{savedir}/config.yaml")

    if args.save_model:
        pipeline.save_pretrained(f"{savedir}/model")
if __name__ == "__main__":
    # Command-line options as (flags, add_argument keyword options), registered
    # in this order. The trailing positional collects config overrides.
    cli_options = [
        (("--inference_config",), {"type": str, "default": "configs/inference/inference_autoregress.yaml"}),
        (("--prompt", "-p"), {"type": str, "default": None}),
        (("--n_prompt", "-n"), {"type": str, "default": ""}),
        (("--seed",), {"type": str, "default": "random"}),
        (("--path_to_first_frame", "-f"), {"type": str, "default": None}),
        (("--prompt_config",), {"type": str, "default": "configs/prompts/default.yaml"}),
        (("--format",), {"type": str, "default": "gif", "choices": ["gif", "mp4"]}),
        (("--save_model",), {"action": "store_true"}),
        (("optional_args",), {"nargs": '*', "default": []}),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # Load the base inference config, then overlay any dotlist overrides
    # passed as trailing positional arguments.
    config = OmegaConf.load(args.inference_config)
    if args.optional_args:
        config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.optional_args))

    main(args, config)
|