|
import math |
|
import os |
|
import torch |
|
import argparse |
|
import torchvision |
|
|
|
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler, |
|
EulerDiscreteScheduler, DPMSolverMultistepScheduler, |
|
HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, |
|
DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler) |
|
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler |
|
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder |
|
from omegaconf import OmegaConf |
|
from torchvision.utils import save_image |
|
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer |
|
|
|
import os, sys |
|
|
|
from opensora.models.ae import ae_stride_config, getae, getae_wrapper |
|
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper |
|
from opensora.models.diffusion.latte.modeling_latte import LatteT2V |
|
from opensora.models.text_encoder import get_text_enc |
|
from opensora.utils.utils import save_video_grid |
|
|
|
sys.path.append(os.path.split(sys.path[0])[0]) |
|
from pipeline_videogen import VideoGenPipeline |
|
|
|
import imageio |
|
|
|
|
|
def main(args):
    """Sample videos (or single images) from text prompts and write them to disk.

    The function loads the autoencoder, the LatteT2V transformer and the T5
    tokenizer/text encoder in float16, wires them into a ``VideoGenPipeline``
    together with the scheduler selected by ``args.sample_method``, runs the
    pipeline once per prompt, saves every sample individually and finally
    saves one combined grid of all samples.

    Args:
        args: Namespace-like object. Attributes read here: ``ae``,
            ``model_path``, ``version``, ``text_encoder_name``,
            ``enable_tiling``, ``tile_overlap_factor``, ``force_images``,
            ``sample_method``, ``save_img_path``, ``text_prompt``,
            ``num_sampling_steps``, ``guidance_scale`` and ``fps``.
            ``args.text_prompt`` is replaced in place by the resolved list of
            prompt strings.

    Raises:
        ValueError: If ``args.sample_method`` is not a supported scheduler
            name, or if no usable prompt was supplied.
    """
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Validate the cheap inputs first so that a typo in the scheduler name or
    # a missing prompt fails immediately instead of after loading the models.
    scheduler = _build_scheduler(args.sample_method)
    args.text_prompt = _load_prompts(args.text_prompt)

    vae = getae_wrapper(args.ae)(
        args.model_path, subfolder="vae", cache_dir='cache_dir'
    ).to(device, dtype=torch.float16)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor

    transformer_model = LatteT2V.from_pretrained(
        args.model_path,
        subfolder=args.version,
        cache_dir="cache_dir",
        torch_dtype=torch.float16,
    ).to(device)
    transformer_model.force_images = args.force_images
    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
    text_encoder = T5EncoderModel.from_pretrained(
        args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16
    ).to(device)

    # The version string looks like '<frames>x<size>x<size>'; the second field
    # is used for both height and width.
    video_length = transformer_model.config.video_length
    image_size = int(args.version.split('x')[1])
    # Entries 1 and 2 of the stride config are presumably the spatial
    # (height, width) strides of the autoencoder -- TODO confirm.
    ae_stride = ae_stride_config[args.ae]
    vae.latent_size = (image_size // ae_stride[1], image_size // ae_stride[2])

    if args.force_images:
        # Image mode: a single frame, written as JPEG.
        video_length = 1
        ext = 'jpg'
    else:
        ext = 'mp4'

    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    print('videogen_pipeline', device)
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         transformer=transformer_model).to(device=device)

    os.makedirs(args.save_img_path, exist_ok=True)

    # Suffix shared by every output file of this run.
    run_tag = f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'

    video_grids = []
    for prompt in args.text_prompt:
        print('Processing the ({}) prompt'.format(prompt))
        videos = videogen_pipeline(prompt,
                                   video_length=video_length,
                                   height=image_size,
                                   width=image_size,
                                   num_inference_steps=args.num_sampling_steps,
                                   guidance_scale=args.guidance_scale,
                                   enable_temporal_attentions=not args.force_images,
                                   num_images_per_prompt=1,
                                   mask_feature=True,
                                   ).video

        if args.force_images:
            # Keep only frame 0 and move the last axis to position 1, as
            # save_image expects channels-first input. Done outside the
            # try-block so every entry of video_grids has the same layout
            # regardless of whether the per-prompt save succeeds.
            videos = videos[:, 0].permute(0, 3, 1, 2)

        # File name: prompt with spaces replaced, truncated to 100 characters,
        # followed by the run tag.
        sample_path = os.path.join(args.save_img_path,
                                   prompt.replace(' ', '_')[:100] + run_tag)
        try:
            if args.force_images:
                # Dividing by 255 suggests the pipeline returns 0-255 pixel
                # values -- TODO confirm against VideoGenPipeline.
                save_image(videos / 255.0, sample_path,
                           nrow=1, normalize=True, value_range=(0, 1))
            else:
                imageio.mimwrite(sample_path, videos[0], fps=args.fps, quality=9)
        except Exception as exc:
            # Best effort: a failed per-prompt save must not abort the run,
            # but report why it failed and let Ctrl-C propagate.
            print('Error when saving {}: {!r}'.format(prompt, exc))
        video_grids.append(videos)

    video_grids = torch.cat(video_grids, dim=0)

    grid_path = os.path.join(args.save_img_path, run_tag)
    if args.force_images:
        # Roughly square grid: ceil(sqrt(N)) images per row.
        save_image(video_grids / 255.0, grid_path,
                   nrow=math.ceil(math.sqrt(len(video_grids))),
                   normalize=True, value_range=(0, 1))
    else:
        video_grids = save_video_grid(video_grids)
        imageio.mimwrite(grid_path, video_grids, fps=args.fps, quality=9)

    print('save path {}'.format(args.save_img_path))


def _build_scheduler(sample_method):
    """Return a default-constructed scheduler for the given sampler name."""
    # Built inside the function so the scheduler classes are only looked up
    # when a scheduler is actually requested.
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    if sample_method not in scheduler_classes:
        raise ValueError(
            'Unknown sample_method {!r}; expected one of: {}'.format(
                sample_method, ', '.join(scheduler_classes)))
    return scheduler_classes[sample_method]()


def _load_prompts(text_prompt):
    """Normalise the prompt argument into a non-empty list of prompt strings."""
    if text_prompt is None:
        raise ValueError('No text prompt given; pass --text_prompt.')
    prompts = text_prompt if isinstance(text_prompt, list) else [text_prompt]

    # A single argument ending in 'txt' is treated as a file with one prompt
    # per line.
    if len(prompts) == 1 and prompts[0].endswith('txt'):
        with open(prompts[0], 'r') as prompt_file:
            stripped = [line.strip() for line in prompt_file]
        # Skip blank lines (e.g. a trailing newline) instead of sampling an
        # empty prompt.
        prompts = [line for line in stripped if line]

    if not prompts:
        raise ValueError('No text prompt found; the prompt list or file is empty.')
    return prompts
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the sampling options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')
    parser.add_argument("--version", type=str, default='65x512x512',
                        choices=['65x512x512', '65x256x256', '17x256x256'])
    parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8')
    parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl')
    parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v",
                        help="directory the samples are written to")
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    # Restrict to the scheduler names main() knows how to construct, so a typo
    # is rejected at parse time rather than after the models are loaded.
    parser.add_argument("--sample_method", type=str, default="PNDM",
                        choices=['DDIM', 'EulerDiscrete', 'DDPM', 'DPMSolverMultistep',
                                 'DPMSolverSinglestep', 'PNDM', 'HeunDiscrete',
                                 'EulerAncestralDiscrete', 'DEISMultistep',
                                 'KDPM2AncestralDiscrete'])
    parser.add_argument("--num_sampling_steps", type=int, default=50)
    parser.add_argument("--fps", type=int, default=24)
    # NOTE(review): --run_time is parsed but not read by main() in this file.
    parser.add_argument("--run_time", type=int, default=0)
    # Required: main() cannot run without at least one prompt.
    parser.add_argument("--text_prompt", nargs='+', required=True,
                        help="one or more prompts, or a single path ending in 'txt' "
                             "containing one prompt per line")
    parser.add_argument('--force_images', action='store_true',
                        help="generate single images (jpg) instead of videos (mp4)")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    args = parser.parse_args()

    main(args)