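# Stylize a video with Stable Diffusion + a TemporalNet2 ControlNet while keeping frames temporally coherent.
# Example invocation (hypothetical file names; assumes a CUDA GPU with diffusers, xformers, and torchvision installed):
#   python stylize_video.py -i input.mp4 -p "an oil painting of a snowy mountain" -o stylized.mp4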
import argparse
import warnings
from pathlib import Path
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline
from torch import Tensor
from torchvision.io.video import read_video, write_video
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import trange

raft_transform = Raft_Large_Weights.DEFAULT.transforms()


@torch.inference_mode()
def stylize_video(
    input_video: Tensor,
    prompt: str,
    strength: float = 0.7,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    controlnet_scale: float = 1.0,
    batch_size: int = 4,
    height: int = 512,
    width: int = 512,
    device: str = "cuda",
) -> Tensor:
"""
Stylize a video with temporal coherence (less flickering!) using HuggingFace's Stable Diffusion ControlNet pipeline.
Args:
input_video (Tensor): Input video tensor of shape (T, C, H, W) and range [0, 1].
prompt (str): Text prompt to condition the diffusion process.
strength (float, optional): How heavily stylization affects the image.
num_steps (int, optional): Number of diffusion steps (tradeoff between quality and speed).
guidance_scale (float, optional): Scale of the text guidance loss (how closely to adhere to text prompt).
controlnet_scale (float, optional): Scale of the ControlNet conditioning (strength of temporal coherence).
batch_size (int, optional): Number of frames to diffuse at once (faster but more memory intensive).
height (int, optional): Height of the output video.
width (int, optional): Width of the output video.
device (str, optional): Device to run stylization process on.
Returns:
Tensor: Output video tensor of shape (T, C, H, W) and range [0, 1].
"""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence annoying TypedStorage warnings
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=ControlNetModel.from_pretrained("wav/TemporalNet2", torch_dtype=torch.float16),
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to(device)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    pipe._progress_bar_config = dict(disable=True)  # disable the per-batch denoising progress bars

    raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=True).eval().to(device)
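
    # Each iteration stylizes a batch of frames: the previous frames and the optical flow between the
    # previous and current frames are stacked into the 6-channel control image that TemporalNet2 uses
    # to keep consecutive stylized frames consistent with each other.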
    output_video = []
    for i in trange(1, len(input_video), batch_size, desc="Diffusing...", unit="frame", unit_scale=batch_size):
        prev = resize(input_video[i - 1 : i - 1 + batch_size], (height, width), antialias=True).to(device)
        curr = resize(input_video[i : i + batch_size], (height, width), antialias=True).to(device)
        prev = prev[: curr.shape[0]]  # make sure prev and curr have the same batch size (for the last batch)

        flow_img = flow_to_image(raft.forward(*raft_transform(prev, curr))[-1]).div(255)  # final RAFT flow estimate, rendered as an RGB image in [0, 1]
        control_img = torch.cat((prev, flow_img), dim=1)

        output, _ = pipe(
            prompt=[prompt] * curr.shape[0],
            image=curr,
            control_image=control_img,
            height=height,
            width=width,
            strength=strength,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            output_type="pt",
            return_dict=False,
        )
        output_video.append(output.permute(0, 2, 3, 1).cpu())  # (B, C, H, W) -> (B, H, W, C) for write_video

    return torch.cat(output_video)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(usage=stylize_video.__doc__)
    parser.add_argument("-i", "--in-file", type=str, required=True)
    parser.add_argument("-p", "--prompt", type=str, required=True)
    parser.add_argument("-o", "--out-file", type=str, default=None)
    parser.add_argument("-s", "--strength", type=float, default=0.7)
    parser.add_argument("-S", "--num-steps", type=int, default=20)
    parser.add_argument("-g", "--guidance-scale", type=float, default=7.5)
    parser.add_argument("-c", "--controlnet-scale", type=float, default=1.0)
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("-H", "--height", type=int, default=512)
    parser.add_argument("-W", "--width", type=int, default=512)
    parser.add_argument("-d", "--device", type=str, default="cuda")
    args = parser.parse_args()

    input_video, _, info = read_video(args.in_file, pts_unit="sec", output_format="TCHW")
    input_video = input_video.div(255)

    output_video = stylize_video(
        input_video=input_video,
        prompt=args.prompt,
        strength=args.strength,
        num_steps=args.num_steps,
        guidance_scale=args.guidance_scale,
        controlnet_scale=args.controlnet_scale,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        device=args.device,
    )

    out_file = f"{Path(args.in_file).stem} | {args.prompt}.mp4" if args.out_file is None else args.out_file
    write_video(out_file, output_video.mul(255), fps=info["video_fps"])