# Spaces: Running on Zero
import logging
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from PIL import Image
from tqdm import tqdm

from i2v_enhance.pipeline_i2vgen_xl import I2VGenXLPipeline
import i2v_enhance.thirdparty.VFI.config as cfg
from i2v_enhance.thirdparty.VFI.Trainer import Model as VFI
from modules.params.vfi import VFIParams
from modules.params.i2v_enhance import I2VEnhanceParams
from utils.loader import download_ckpt


def vfi_init(ckpt_cfg: VFIParams, device_id=0):
    """Build the video-frame-interpolation (VFI) model and load its weights.

    Args:
        ckpt_cfg: Checkpoint settings. ``ckpt_path_local`` and
            ``ckpt_path_global`` are forwarded to ``download_ckpt``.
        device_id: Rank the model should live on. Only ``0`` is supported.

    Returns:
        The ``VFI`` model after ``load_model``, ``eval()`` and ``device()``
        have been called on it.

    Raises:
        AssertionError: If ``device_id`` is not ``0``.
    """
    # Reject unsupported ranks up front, before the module-level VFI config
    # is mutated and before the checkpoint is fetched and loaded.
    assert device_id == 0, "VFI on rank!=0 not implemented yet."

    cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(
        F=32, depth=[2, 2, 2, 4, 4])
    # NOTE(review): -1 is presumably a "not distributed" local rank -- confirm
    # against Trainer.Model.
    vfi = VFI(-1)
    ckpt_file = Path(download_ckpt(
        local_path=ckpt_cfg.ckpt_path_local,
        global_path=ckpt_cfg.ckpt_path_global))
    vfi.load_model(ckpt_file.as_posix())
    vfi.eval()
    # NOTE(review): presumably moves the network onto the GPU -- confirm
    # against Trainer.Model.device.
    vfi.device()
    return vfi
def _to_rgb_uint8(frame):
    """Turn one ``(c, h, w)`` tensor into an ``(h, w, c)`` uint8 array with the channel order reversed."""
    pixels = frame.detach().cpu().numpy().transpose(1, 2, 0) * 255.0
    return pixels.astype(np.uint8)[:, :, ::-1]


def vfi_process(video, vfi, video_len):
    """Insert one interpolated frame between consecutive frames of ``video``.

    Only the first ``video_len // 2 + 1`` frames of ``video`` are used. For
    ``n`` source frames the result holds ``2 * n - 1`` frames (each source
    frame followed by the interpolated frame towards its successor, then the
    final source frame); when ``video_len`` is even the final frame is
    appended once more.

    Args:
        video: Sequence of ``(h, w, c)`` arrays with at least three channels.
            NOTE(review): values are divided by 255, so presumably uint8
            RGB in ``[0, 255]`` -- confirm against the caller.
        vfi: Interpolation model exposing
            ``inference(I0, I2, TTA=..., fast_TTA=...)``.
        video_len: Frame count that controls how many source frames are used
            and whether the last frame is duplicated.

    Returns:
        A list of ``PIL.Image`` frames, each resized to 1280x720.
    """
    n_source = video_len // 2 + 1
    # Keep the first three channels, scale to [0, 1] and reverse the channel
    # order before handing the frames to the model.
    scaled = [frame[:, :, :3] / 255. for frame in video[:n_source]]
    flipped = [frame[:, :, ::-1] for frame in scaled]
    batch = rearrange(torch.from_numpy(np.stack(flipped, axis=0)),
                      'b h w c -> b c h w').to("cuda", torch.float32)

    frames = []
    for i in tqdm(range(batch.shape[0] - 1), desc="VFI"):
        frames.append(_to_rgb_uint8(batch[i]))
        # Slices are passed inline so that no local name keeps a view of the
        # CUDA tensor alive after the loop; otherwise the `del` below would
        # not actually release the memory before `empty_cache()`.
        frames.append(_to_rgb_uint8(
            vfi.inference(batch[i:i + 1], batch[i + 1:i + 2],
                          TTA=True, fast_TTA=True)[0]))

    last = _to_rgb_uint8(batch[-1])
    frames.append(last)
    if video_len % 2 == 0:
        # Even target lengths get the final frame a second time.
        frames.append(last)

    # Drop this function's references to the model and the GPU batch before
    # asking PyTorch to release cached CUDA memory.
    del vfi
    del batch
    torch.cuda.empty_cache()

    return [Image.fromarray(frame).resize((1280, 720)) for frame in frames]
def i2v_enhance_init(i2vgen_cfg: I2VEnhanceParams):
    """Load the I2VGen-XL enhancement pipeline and a seeded generator.

    The pipeline is first loaded from ``i2vgen_cfg.ckpt_path_local``. If that
    fails for any reason, it is loaded from ``i2vgen_cfg.ckpt_path_global``
    instead and saved to the local path so that the next start-up can use it.

    Args:
        i2vgen_cfg: Settings providing ``ckpt_path_local`` and
            ``ckpt_path_global``.

    Returns:
        A ``(pipeline, generator)`` tuple: the fp16 pipeline with model CPU
        offload enabled, and the generator returned by
        ``torch.manual_seed(8888)``.
    """
    # Seed first so the call order matches what callers have always seen.
    generator = torch.manual_seed(8888)
    load_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}
    try:
        pipeline = I2VGenXLPipeline.from_pretrained(
            i2vgen_cfg.ckpt_path_local, **load_kwargs)
    except Exception as exc:
        # Deliberately broad: whatever prevented loading the local copy, fall
        # back to the global source. Report the reason instead of hiding it.
        logging.getLogger(__name__).warning(
            "Could not load I2VGen-XL from %r (%s); loading from %r and "
            "saving it locally.",
            i2vgen_cfg.ckpt_path_local, exc, i2vgen_cfg.ckpt_path_global)
        pipeline = I2VGenXLPipeline.from_pretrained(
            i2vgen_cfg.ckpt_path_global, **load_kwargs)
        pipeline.save_pretrained(i2vgen_cfg.ckpt_path_local)
    pipeline.enable_model_cpu_offload()
    return pipeline, generator
def i2v_enhance_process(image, video, pipeline, generator, overlap_size, strength, chunk_size=38, use_randomized_blending=False):
    """Enhance ``video`` with the image-to-video pipeline.

    Without randomized blending the pipeline is called once on the whole
    video, conditioned on ``image``.

    With randomized blending the pipeline needs one starting image per chunk,
    so the first frame of every full chunk is enhanced in a preliminary pass
    and the result replaces ``image``. Trailing frames that do not fill a
    complete chunk are then dropped before the main pass.

    Args:
        image: Conditioning image for the pipeline.
        video: Sequence of frames to enhance.
        pipeline: Callable pipeline whose result exposes ``.frames``.
        generator: Random generator forwarded to the pipeline.
        overlap_size: Number of frames shared by consecutive chunks.
        strength: Enhancement strength forwarded to the pipeline.
        chunk_size: Number of frames processed per chunk.
        use_randomized_blending: Whether to run the key-frame pass described
            above before the main pass.

    Returns:
        ``pipeline(...).frames[0]`` of the main pass.

    Raises:
        ValueError: If randomized blending is requested but not a single
            complete chunk of ``chunk_size`` frames fits into ``video``.
    """
    prompt = "High Quality, HQ, detailed."
    negative_prompt = "Distorted, blurry, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"

    # Settings shared by the key-frame pass and the main pass.
    shared_kwargs = dict(
        prompt=prompt,
        height=720,
        width=1280,
        strength=strength,
        num_inference_steps=30,
        decode_chunk_size=1,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
    )

    if use_randomized_blending:
        stride = chunk_size - overlap_size
        # Start index of every chunk that fits completely inside the video.
        chunk_starts = [
            start for start in range(0, len(video), stride)
            if len(video[start:start + chunk_size]) == chunk_size
        ]
        if not chunk_starts:
            raise ValueError(
                "Randomized blending needs at least one complete chunk: got "
                f"{len(video)} frames with chunk_size={chunk_size} and "
                f"overlap_size={overlap_size}.")

        # Enhance the key frames (the first frame of each chunk) in one go;
        # the enhanced key frames become the per-chunk starting images.
        key_frames = [video[start] for start in chunk_starts]
        image = pipeline(
            image=image,
            video=key_frames,
            overlap_size=0,
            chunk_size=len(key_frames),
            num_frames=len(key_frames),
            **shared_kwargs,
        ).frames[0]

        # Remove the last few frames (< chunk_size) of the video that do not
        # fit into one chunk.
        max_idx = stride * (len(chunk_starts) - 1) + chunk_size
        video = video[:max_idx]

    return pipeline(
        image=image,
        video=video,
        overlap_size=overlap_size,
        chunk_size=chunk_size,
        num_frames=chunk_size,
        **shared_kwargs,
    ).frames[0]