import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union

from pathlib import Path
from loguru import logger

import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline

try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment,
    )
except ImportError:
    # xfuser is optional; sequence parallelism is disabled without it.
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None


def parallelize_transformer(pipe):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,
        text_states_2: Optional[torch.Tensor] = None,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,
        return_dict: bool = True,
    ):
        # Pick the spatial dim whose patchified size (patch size is 2 along
        # height/width) divides evenly across the sequence-parallel group.
        if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
            split_dim = -2
        elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
            split_dim = -1
        else:
            raise ValueError(
                f"Cannot split video sequence into ulysses_degree x ring_degree "
                f"({get_sequence_parallel_world_size()}) parts evenly"
            )

        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        # Each rank keeps only its shard of the latent.
        x = torch.chunk(x, get_sequence_parallel_world_size(), dim=split_dim)[
            get_sequence_parallel_rank()
        ]

        # Slice the rotary embeddings to match this rank's shard.
        dim_thw = freqs_cos.shape[-1]
        freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
        freqs_cos = torch.chunk(
            freqs_cos, get_sequence_parallel_world_size(), dim=split_dim - 1
        )[get_sequence_parallel_rank()]
        freqs_cos = freqs_cos.reshape(-1, dim_thw)

        dim_thw = freqs_sin.shape[-1]
        freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
        freqs_sin = torch.chunk(
            freqs_sin, get_sequence_parallel_world_size(), dim=split_dim - 1
        )[get_sequence_parallel_rank()]
        freqs_sin = freqs_sin.reshape(-1, dim_thw)

        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        for block in transformer.double_blocks + transformer.single_blocks:
            block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

        output = original_forward(
            x,
            t,
            text_states,
            text_mask,
            text_states_2,
            freqs_cos,
            freqs_sin,
            guidance,
            return_dict,
        )

        return_dict = not isinstance(output, tuple)
        sample = output["x"]
        # Gather the shards from all ranks back into the full sequence.
        sample = get_sp_group().all_gather(sample, dim=split_dim)
        output["x"] = sample
        return output

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward
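

# A minimal sketch (not part of the pipeline) of the shard/gather round trip
# `new_forward` performs, for a 720x1280 latent (90x160 after the VAE's 8x
# spatial compression) and a sequence-parallel group of size 2:
#
#     x = torch.zeros(1, 16, 33, 90, 160)        # (B, C, T, H, W)
#     world_size, rank = 2, 0
#     # height: 90 // 2 % 2 == 1  -> cannot split height evenly
#     # width:  160 // 2 % 2 == 0 -> split_dim = -1
#     shard = torch.chunk(x, world_size, dim=-1)[rank]
#     # shard.shape == (1, 16, 33, 90, 80); all_gather along dim=-1 after the
#     # forward pass reassembles the full (1, 16, 33, 90, 160) sample.
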
class Inference(object):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload

        self.args = args
        self.device = (
            device
            if device is not None
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.logger = logger
        self.parallel_args = parallel_args

    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (int or str, optional): The device for inference. Defaults to CUDA if available, otherwise CPU.
        """
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # Set up sequence parallelism (xfuser) when requested.
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, \
                "Ulysses Attention and Ring Attention require the xfuser package."
            assert args.use_cpu_offload is False, \
                "Cannot enable use_cpu_offload in the distributed environment."
            dist.init_process_group("nccl")
            assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
                "Number of GPUs should be equal to ring_degree * ulysses_degree."
            init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        torch.set_grad_enabled(False)

        # Build the main model.
        logger.info("Building model...")
        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        if args.use_fp8:
            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # Build the VAE.
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # Text encoder: the video prompt template reserves `crop_start` tokens,
        # so extend the tokenizer's max length accordingly.
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0
            )
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        prompt_template = (
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        )

        prompt_template_video = (
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        )

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )
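
    # Typical multi-GPU launch for the sequence-parallel path (a sketch; the
    # flag names follow the repo's CLI and the entry script is assumed to be
    # sample_video.py):
    #
    #     torchrun --nproc_per_node=8 sample_video.py \
    #         --ulysses-degree 8 --ring-degree 1 ...
    #
    # torchrun sets LOCAL_RANK, which from_pretrained reads above to pick the
    # per-process CUDA device.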

    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        load_key = args.load_key
        dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None

        if dit_weight is None:
            model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
            files = list(model_dir.glob("*.pt"))
            if len(files) == 0:
                raise ValueError(f"No model weights found in {model_dir}")
            if files[0].name.startswith("pytorch_model_"):
                model_path = model_dir / f"pytorch_model_{load_key}.pt"
                bare_model = True
            elif any(f.name.endswith("_model_states.pt") for f in files):
                files = [f for f in files if f.name.endswith("_model_states.pt")]
                model_path = files[0]
                if len(files) > 1:
                    logger.warning(
                        f"Multiple model weights found in {model_dir}, using {model_path}"
                    )
                bare_model = False
            else:
                raise ValueError(
                    f"Invalid model path: {model_dir} with unrecognized weight format: "
                    f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                    f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                    f"`*_model_states.pt` (saved by deepspeed) can be parsed. If you want to load a "
                    f"specific weight file, please provide the full path to the file."
                )
        else:
            if dit_weight.is_dir():
                files = list(dit_weight.glob("*.pt"))
                if len(files) == 0:
                    raise ValueError(f"No model weights found in {dit_weight}")
                if files[0].name.startswith("pytorch_model_"):
                    model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                    bare_model = True
                elif any(f.name.endswith("_model_states.pt") for f in files):
                    files = [f for f in files if f.name.endswith("_model_states.pt")]
                    model_path = files[0]
                    if len(files) > 1:
                        logger.warning(
                            f"Multiple model weights found in {dit_weight}, using {model_path}"
                        )
                    bare_model = False
                else:
                    raise ValueError(
                        f"Invalid model path: {dit_weight} with unrecognized weight format: "
                        f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                        f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                        f"`*_model_states.pt` (saved by deepspeed) can be parsed. If you want to load a "
                        f"specific weight file, please provide the full path to the file."
                    )
            elif dit_weight.is_file():
                model_path = dit_weight
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path does not exist: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )
        model.load_state_dict(state_dict, strict=True)
        return model
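
    # The two checkpoint layouts this loader distinguishes (sketch; parameter
    # names are illustrative):
    #
    #     bare:    {"some.layer.weight": ..., ...}      # loaded directly
    #     wrapped: {"module": {...}, "ema": {...}}      # loads state_dict[load_key]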

    @staticmethod
    def parse_size(size):
        """Normalize `size` to a two-element (height, width) sequence."""
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
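
    # Usage (sketch):
    #     Inference.parse_size(720)          # -> [720, 720]
    #     Inference.parse_size((720, 1280))  # -> (720, 1280)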


class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None,
    ):
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        self.default_negative_prompt = NEGATIVE_PROMPT
        if self.parallel_args["ulysses_degree"] > 1 or self.parallel_args["ring_degree"] > 1:
            parallelize_transformer(self.pipeline)

    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the denoising scheduler and diffusion pipeline for inference."""
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)

        return pipeline

    def get_rotary_pos_embed(self, video_length, height, width):
        target_ndim = 3
        ndim = 5 - 2  # B, C, F, H, W -> F, H, W

        # Compute latent sizes from the VAE's temporal/spatial compression.
        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.args.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin
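
    # Worked example (a sketch; the concrete numbers assume the released 720p
    # config: "884" VAE, patch_size [1, 2, 2], hidden_size 3072, heads_num 24,
    # rope_dim_list [16, 56, 56]):
    #
    #     video_length=129, height=720, width=1280
    #     -> latents_size = [33, 90, 160]
    #     -> rope_sizes   = [33, 45, 80]   (divide by patch size per axis)
    #     -> n_tokens     = 33 * 45 * 80 = 118800
    #     head_dim = 3072 // 24 = 128, and 16 + 56 + 56 == 128, as asserted.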

    @torch.no_grad()
    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                height (int): The height of the output video. Default is 192.
                width (int): The width of the output video. Default is 336.
                video_length (int): The frame number of the output video. Default is 129.
                seed (int or List[int]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
                infer_steps (int): The number of inference steps. Default is 50.
        """
        out_dict = dict()

        # Seed handling: normalize to one seed per generated sample.
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        # One generator per sample so results are reproducible per seed.
        generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        out_dict["seeds"] = seeds

        # Size and frame-count checks.
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height`, `width`, and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got {video_length}"
            )

        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # Prompt
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        # Negative prompt
        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # Scheduler: rebuild with the per-call flow shift.
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver,
        )
        self.pipeline.scheduler = scheduler

        # Build rotary position embeddings for the target size.
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_video_length, target_height, target_width
        )
        n_tokens = freqs_cos.shape[0]

        # Log the effective inference arguments.
        debug_str = f"""
                        height: {target_height}
                         width: {target_width}
                  video_length: {target_video_length}
                        prompt: {prompt}
                    neg_prompt: {negative_prompt}
                          seed: {seed}
                   infer_steps: {infer_steps}
         num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                      n_tokens: {n_tokens}
                    flow_shift: {flow_shift}
       embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # Run the pipeline.
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time:.2f}s")

        return out_dict
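

if __name__ == "__main__":
    # A minimal end-to-end sketch (not part of the original module): build the
    # sampler from the repo's CLI args and generate one clip. `parse_args` is
    # assumed to be the CLI parser in hyvideo.config, as used by
    # sample_video.py; the prompt and sampling settings are illustrative.
    from hyvideo.config import parse_args

    cli_args = parse_args()
    sampler = HunyuanVideoSampler.from_pretrained(
        Path(cli_args.model_base), args=cli_args
    )
    outputs = sampler.predict(
        prompt="a cat walks on the grass, realistic style",
        height=720,
        width=1280,
        video_length=129,
        seed=42,
        infer_steps=50,
        guidance_scale=6.0,
    )
    print(f"Generated {len(outputs['samples'])} sample(s) with seeds {outputs['seeds']}")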