# Scraped from the Hugging Face Space file view of app.py (author: vilarin;
# commit 27a03f9 "Update app.py", verified; file size 4.15 kB).
import os
import gradio as gr
import torch
import spaces
import random
from PIL import Image
from glob import glob
from pathlib import Path
from typing import Optional
from diffsynth import ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
from diffsynth import ModelManager
from diffusers.utils import load_image, export_to_video
import uuid
from huggingface_hub import hf_hub_download
# NOTE(review): huggingface_hub is already imported above (and likely
# transitively via diffsynth/diffusers), and it appears to read
# HF_HUB_ENABLE_HF_TRANSFER at import time, so this assignment probably has
# no effect on downloads. Move it above the imports if hf_transfer is wanted.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional Hugging Face access token; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Constants
# Upper bound for user/random seeds: the int32 maximum (2**31 - 1).
MAX_SEED = 2147483647

# Hide the Gradio footer.
CSS = """
footer {
visibility: hidden;
}
"""

# Client-side hook run on page load: force the dark theme by setting the
# ``__theme=dark`` query parameter and reloading once. Using URL/searchParams
# (instead of string concatenation) keeps any existing query parameters valid
# and avoids leaking an implicit global variable.
JS = """function () {
    const url = new URL(window.location.href);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.replace(url.href);
    }
}"""
# Build the video pipeline once, at import time, and only when CUDA is
# reported available. Two models are loaded in fp16 on "cuda" (the SVD
# img2vid-xt base and ExVideo-SVD-128f-v1), with Hugging Face listed as the
# download source. Without CUDA, `pipe` is never defined, and calling
# `generate` would then fail with a NameError.
if torch.cuda.is_available():
    _svd_model_ids = [
        "stable-video-diffusion-img2vid-xt",
        "ExVideo-SVD-128f-v1",
    ]
    model_manager = ModelManager(
        device="cuda",
        torch_dtype=torch.float16,
        downloading_priority=["HuggingFace"],
        model_id_list=_svd_model_ids,
    )
    pipe = SVDVideoPipeline.from_model_manager(model_manager)
# Function source code modified from multimodalart/stable-video-diffusion.


def _load_rgb_image(image):
    """Return *image* as a 3-channel RGB ``PIL.Image.Image``.

    Accepts a filesystem path (what ``gr.Image(type="filepath")`` and
    ``gr.Examples`` deliver), an existing PIL image, or a numpy array (the
    default ``gr.Image`` output type).
    """
    if isinstance(image, Image.Image):
        pil_image = image
    elif isinstance(image, (str, os.PathLike)):
        # The context manager releases the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    else:
        pil_image = Image.fromarray(image)
    # Normalize every non-RGB mode (RGBA, P, L, LA, CMYK, ...), not just RGBA.
    return pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")


@spaces.GPU(duration=120)
def generate(
    image: object,
    seed: Optional[int] = -1,
    motion_bucket_id: int = 127,
    fps_id: int = 25,
    output_folder: str = "outputs",
    progress=gr.Progress(track_tqdm=True)):
    """Render a 128-frame 512x512 video from a single image and save it as MP4.

    Args:
        image: Input image as a file path, PIL image, or numpy array.
        seed: RNG seed; -1 or None draws a random seed in [0, MAX_SEED].
        motion_bucket_id: Passed through to the pipeline's ``motion_bucket_id``.
        fps_id: Frame rate passed to both the pipeline and the MP4 writer.
        output_folder: Directory the video is written into (created if missing).
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        ``(video_path, seed)``, where ``seed`` is the one actually used, so the
        UI can display it.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Gradio sliders may deliver floats; normalize to ints before use.
    seed = -1 if seed is None else int(seed)
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    fps = int(fps_id)

    rgb_image = _load_rgb_image(image)

    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)
    # A random name cannot collide between concurrent requests, unlike a
    # count of the files already present.
    video_path = os.path.join(output_folder, f"{uuid.uuid4().hex}.mp4")

    # NOTE(review): assumes the pipeline result exposes `.frames` as a
    # per-batch list of frame sequences; confirm against diffsynth.
    frames = pipe(
        input_image=rgb_image.resize((512, 512)),
        num_frames=128,
        fps=fps,
        height=512,
        width=512,
        motion_bucket_id=int(motion_bucket_id),
        num_inference_steps=50,
        min_cfg_scale=2,
        max_cfg_scale=2,
        contrast_enhance_scale=1.2
    ).frames[0]
    export_to_video(frames, video_path, fps=fps)
    return video_path, seed
# Sample inputs listed under the UI via gr.Examples; the paths are relative
# to the process's working directory.
examples = ["./train.jpg", "./girl.webp", "./robo.jpg"]
# Gradio Interface
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>Exvideo📽️</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1'>ExVideo</a> image-to-video generation<br><b>Update</b>: first version</center></p>")
    with gr.Row():
        # Deliver uploads (and example images) to generate() as a file path,
        # which it opens with PIL's Image.open; gr.Image's default output type
        # would pass a numpy array instead.
        image = gr.Image(label='Upload Image', height=600, scale=2, type="filepath")
        video = gr.Video(label="Generated Video", height=600, scale=2)
        with gr.Accordion("Advanced Options", open=True):
            with gr.Column(scale=1):
                # -1 makes generate() draw a random seed; the seed it actually
                # used is written back to this slider (see `outputs` below).
                seed = gr.Slider(
                    label="Seed (-1 Random)",
                    minimum=-1,
                    maximum=MAX_SEED,
                    step=1,
                    value=-1,
                )
                motion_bucket_id = gr.Slider(
                    label="Motion bucket id",
                    info="Controls how much motion to add/remove from the image",
                    value=127,
                    minimum=1,
                    maximum=255,
                )
                # generate() always renders 128 frames, so duration = 128 / fps.
                fps_id = gr.Slider(
                    label="Frames per second",
                    info="The length of your video in seconds will be 128/fps",
                    value=25,
                    minimum=5,
                    maximum=30,
                )
                submit_btn = gr.Button(value="Generate")
                clear_btn = gr.ClearButton([image, seed, video])
    # Examples call generate() with the image only; the other parameters use
    # their defaults. Results are cached lazily on first use.
    gr.Examples(
        examples=examples,
        inputs=image,
        outputs=[video, seed],
        fn=generate,
        cache_examples="lazy",
        examples_per_page=4,
    )
    submit_btn.click(
        fn=generate,
        inputs=[image, seed, motion_bucket_id, fps_id],
        outputs=[video, seed],
        api_name="video",
    )
demo.queue().launch()