fastvideogen / app.py
fantaxy's picture
Update app.py
ec476cf verified
raw
history blame
13.8 kB
import spaces
from functools import lru_cache
import gradio as gr
from gradio_toggle import Toggle
import torch
from huggingface_hub import snapshot_download
from transformers import CLIPProcessor, CLIPModel, pipeline
import random
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer
from xora.utils.conditioning_method import ConditioningMethod
from pathlib import Path
import safetensors.torch
import json
import numpy as np
import cv2
from PIL import Image
import tempfile
import os
import gc
import csv
from datetime import datetime
from openai import OpenAI
# Korean -> English translator used by process_prompt().
# NOTE: `pipeline` here is transformers.pipeline; the name is rebound to the
# video pipeline further down, so this must run before that rebinding.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Numerics: disable TF32 and reduced-precision reductions for matmul/cuDNN.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
# `preferred_blas_library` is a function in recent PyTorch releases; assigning
# a string to it only replaced the function object and selected nothing.
# Call it when available (older releases do not have it).
if callable(getattr(torch.backends.cuda, "preferred_blas_library", None)):
    torch.backends.cuda.preferred_blas_library("cublas")
torch.set_float32_matmul_precision("highest")

MAX_SEED = np.iinfo(np.int32).max

# Load Hugging Face token if needed
hf_token = os.getenv("HF_TOKEN")
openai_api_key = os.getenv("OPENAI_API_KEY")
# OpenAI() raises at construction when no API key is available. Prompt
# enhancement is optional, so keep the app bootable without a key; the
# enhancement call runs inside a try/except and falls back to the raw prompt.
client = OpenAI(api_key=openai_api_key) if openai_api_key else None

system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
# Explicit encoding: do not depend on the container's locale.
with open(system_prompt_t2v_path, "r", encoding="utf-8") as f:
    system_prompt_t2v = f.read()

# Set model download directory within Hugging Face Spaces
model_path = "asset"
commit_hash = 'c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
if not os.path.exists(model_path):
    snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)

# Global variables to load components
vae_dir = Path(model_path) / "vae"
unet_dir = Path(model_path) / "unet"
scheduler_dir = Path(model_path) / "scheduler"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(torch.device("cuda:0"))
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
def process_prompt(prompt):
    """Return *prompt* in English, machine-translating it if it contains Hangul.

    A prompt counts as Korean when any character is a precomposed Hangul
    syllable (U+AC00 '가' .. U+D7A3 '힣'). Other prompts are returned unchanged.

    Args:
        prompt: User-supplied prompt text (Korean, English, or mixed).

    Returns:
        The ``translation_text`` produced by the module-level ``translator``
        when Hangul is present, otherwise ``prompt`` itself.
    """
    # The range bounds are written as code points so they survive any
    # re-encoding of this file; a garbled multi-character literal makes ord()
    # raise TypeError for every non-empty prompt.
    if any(0xAC00 <= ord(ch) <= 0xD7A3 for ch in prompt):
        return translator(prompt)[0]["translation_text"]
    return prompt
def compute_clip_embedding(text=None):
    """Encode *text* with the module-level CLIP model and return a flat float list.

    Args:
        text: Value passed straight to ``clip_processor(text=...)`` (a string
            or list of strings).

    Returns:
        The CLIP text features, moved to CPU and flattened into a plain
        Python list.
    """
    # Send inputs to wherever the model actually lives (it is placed on
    # cuda:0 explicitly), not the separately computed `device`.
    inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(clip_model.device)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        outputs = clip_model.get_text_features(**inputs)
    return outputs.cpu().numpy().flatten().tolist()
def load_vae(vae_dir):
    """Build the causal video autoencoder stored under *vae_dir*.

    Reads ``config.json`` for the architecture and
    ``vae_diffusion_pytorch_model.safetensors`` for the weights.

    Args:
        vae_dir: ``pathlib.Path`` of the VAE folder.

    Returns:
        The VAE moved to ``device`` and cast to ``torch.bfloat16``.
    """
    config = json.loads((vae_dir / "config.json").read_text())
    model = CausalVideoAutoencoder.from_config(config)

    weights = safetensors.torch.load_file(vae_dir / "vae_diffusion_pytorch_model.safetensors")
    model.load_state_dict(weights)

    model = model.to(device)
    return model.to(torch.bfloat16)
def load_unet(unet_dir):
    """Build the 3D transformer (denoiser) stored under *unet_dir*.

    Reads ``config.json`` for the architecture and
    ``unet_diffusion_pytorch_model.safetensors`` for the weights, which must
    match the architecture exactly (``strict=True``).

    Args:
        unet_dir: ``pathlib.Path`` of the transformer folder.

    Returns:
        The transformer moved to ``device`` and cast to ``torch.bfloat16``.
    """
    config = Transformer3DModel.load_config(unet_dir / "config.json")
    model = Transformer3DModel.from_config(config)

    weights = safetensors.torch.load_file(unet_dir / "unet_diffusion_pytorch_model.safetensors")
    model.load_state_dict(weights, strict=True)

    model = model.to(device)
    return model.to(torch.bfloat16)
def load_scheduler(scheduler_dir):
    """Instantiate the rectified-flow scheduler configured in *scheduler_dir*.

    Args:
        scheduler_dir: ``pathlib.Path`` containing ``scheduler_config.json``.

    Returns:
        A ``RectifiedFlowScheduler`` built from that config.
    """
    config_file = scheduler_dir / "scheduler_config.json"
    return RectifiedFlowScheduler.from_config(
        RectifiedFlowScheduler.load_config(config_file)
    )
# Preset options for resolution and frame configuration.
# The UI shows only "label" while width/height/num_frames drive generation,
# so each label must state exactly those numbers, and labels must be unique
# (preset_changed looks entries up by label).
preset_options = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
    {"label": "768x768, 100 frames", "width": 768, "height": 768, "num_frames": 100},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
    {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
def preset_changed(preset):
    """Sync the hidden height/width/frame sliders with the selected preset.

    The dropdown event is wired to exactly three outputs (height, width and
    num-frames sliders). So this returns one ``gr.update`` per slider,
    carrying both the new value and the visibility. Previously it returned six
    items, and the visibility half was dropped.

    Args:
        preset: A ``label`` from ``preset_options``, or "Custom".

    Returns:
        A 3-tuple of ``gr.update`` objects for (height, width, num_frames).
        For a known preset, the sliders take its values and stay hidden. For
        "Custom" or an unrecognised label, the sliders keep their current
        values and become visible for manual editing.
    """
    selected = next((item for item in preset_options if item["label"] == preset), None)
    if preset == "Custom" or selected is None:
        # Separate dicts per output: Gradio may mutate update dicts in place.
        return tuple(gr.update(visible=True) for _ in range(3))
    return (
        gr.update(value=selected["height"], visible=False),
        gr.update(value=selected["width"], visible=False),
        gr.update(value=selected["num_frames"], visible=False),
    )
# Load models and assemble the Xora text-to-video pipeline.
# NOTE: rebinding `pipeline` below shadows the `transformers.pipeline`
# factory imported at the top of the file; the translator was already built
# with it earlier.
_gpu = torch.device("cuda:0")
_pixart_repo = "PixArt-alpha/PixArt-XL-2-1024-MS"

vae = load_vae(vae_dir)
unet = load_unet(unet_dir)
scheduler = load_scheduler(scheduler_dir)
patchifier = SymmetricPatchifier(patch_size=1)

# T5 text encoder and tokenizer come from the PixArt-alpha repository.
text_encoder = T5EncoderModel.from_pretrained(_pixart_repo, subfolder="text_encoder")
text_encoder = text_encoder.to(_gpu)
tokenizer = T5Tokenizer.from_pretrained(_pixart_repo, subfolder="tokenizer")

pipeline = XoraVideoPipeline(
    transformer=unet,
    patchifier=patchifier,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    vae=vae,
)
pipeline = pipeline.to(_gpu)
def enhance_prompt_if_enabled(prompt, enhance_toggle):
    """Optionally rewrite *prompt* with an OpenAI chat model.

    Args:
        prompt: Prompt text to enhance.
        enhance_toggle: When falsy, *prompt* is returned unchanged and no
            request is made.

    Returns:
        The model's rewritten prompt (stripped). Returns the original *prompt*
        if enhancement is disabled, the request raises, or the model returns
        no text.
    """
    if not enhance_toggle:
        print("Enhance toggle is off, Prompt: ", prompt)
        return prompt

    messages = [
        {"role": "system", "content": system_prompt_t2v},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            # "gpt-4-mini" is not a valid OpenAI model id, so every request failed.
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=200,
        )
        # `content` may be None (e.g. refusals); treat that as "no text".
        enhanced = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Best-effort feature: never fail generation because enhancement failed.
        print(f"Error: {e}")
        return prompt

    if not enhanced:
        return prompt
    print("Enhanced Prompt: ", enhanced)
    return enhanced
def _write_video_mp4(frames, frame_rate):
    """Encode RGB frames into a new temporary .mp4 file and return its path.

    Args:
        frames: float array indexed (frame, height, width, channel) with RGB
            channels. Values are presumably in [0, 1] (implied by the x255
            scaling); TODO confirm the pipeline's output range.
        frame_rate: Frames per second written to the container.

    Returns:
        Path of the written file. The caller owns it; it is not deleted here.
    """
    # Create the file securely instead of tempfile.mktemp(), which only
    # returns a name and is open to a create race; OpenCV overwrites it below.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name

    # Saturate out-of-range values instead of letting the uint8 cast wrap them.
    frames_u8 = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = frames_u8.shape[1:3]
    writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height)
    )
    try:
        # OpenCV expects BGR, so reverse the RGB channel axis.
        for frame in frames_u8[..., ::-1]:
            writer.write(frame)
    finally:
        writer.release()
    return output_path


@spaces.GPU(duration=90)
def generate_video_from_text_90(
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=random.randint(0, MAX_SEED),  # NOTE: drawn once at import; the UI always passes a seed
    num_inference_steps=30,
    guidance_scale=3.2,
    height=768,
    width=768,
    num_frames=60,
    progress=gr.Progress(),
):
    """Generate a video from a text prompt and return the path of the .mp4.

    Korean prompts (and negative prompts) are first translated to English.
    The prompt is then optionally enhanced, and the Xora pipeline is run
    unconditionally (text only).

    Args:
        prompt: Text description; must be at least 50 characters after
            translation.
        enhance_prompt_toggle: Whether to rewrite the prompt via OpenAI.
        negative_prompt: Content to steer away from.
        frame_rate: Frame rate passed to the pipeline and used for the file.
        seed: Seed for the CUDA ``torch.Generator``.
        num_inference_steps: Denoising steps (also used for progress).
        guidance_scale: Classifier-free guidance scale.
        height, width, num_frames: Requested output geometry.
        progress: Gradio progress tracker.

    Returns:
        Filesystem path of the generated .mp4.

    Raises:
        gr.Error: If the prompt is too short or the pipeline raises.
    """
    # Prompt preprocessing (Korean -> English)
    prompt = process_prompt(prompt)
    negative_prompt = process_prompt(negative_prompt)

    if len(prompt.strip()) < 50:
        raise gr.Error(
            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
            duration=5,
        )

    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)

    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": None,
    }

    generator = torch.Generator(device="cuda").manual_seed(seed)

    def gradio_progress_callback(self, step, timestep, kwargs):
        # Step-end hook for the pipeline; reports the fraction of steps done.
        progress((step + 1) / num_inference_steps)

    try:
        with torch.no_grad():
            images = pipeline(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=gradio_progress_callback,
            ).images
    except Exception as e:
        raise gr.Error(
            f"An error occurred while generating the video. Please try again. Error: {e}",
            duration=5,
        ) from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()

    # Drop the leading dim and move channels last: (C, F, H, W) -> (F, H, W, C).
    # Assumes the pipeline returns (1, C, F, H, W); TODO confirm.
    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    del images

    output_path = _write_video_mp4(video_np, frame_rate)
    del video_np
    torch.cuda.empty_cache()
    return output_path
def create_advanced_options():
    """Build the "Step 4: Advanced Options" accordion and return its inputs.

    Intended to be called inside an active ``gr.Blocks`` layout.

    Returns:
        ``[seed, inference_steps, guidance_scale, height_slider,
        width_slider, num_frames_slider]`` in that order. The last three are
        created hidden (``visible=False``).
    """
    with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
        inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
        guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
        height_slider = gr.Slider(
            label="4.4 Height",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        width_slider = gr.Slider(
            label="4.5 Width",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        num_frames_slider = gr.Slider(
            # Was "4.5", duplicating the width slider's step number.
            label="4.6 Number of Frames",
            minimum=1,
            maximum=500,
            step=1,
            value=60,
            visible=False,
        )
    return [
        seed,
        inference_steps,
        guidance_scale,
        height_slider,
        width_slider,
        num_frames_slider,
    ]
# Text-to-video UI: prompt inputs, resolution preset, advanced options, output.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Column():
        # Step 1: main prompt (Korean or English; Korean is translated before use).
        txt2vid_prompt = gr.Textbox(
            label="Step 1: Enter Your Prompt (ν•œκΈ€ λ˜λŠ” μ˜μ–΄)",
            placeholder="μƒμ„±ν•˜κ³  싢은 λΉ„λ””μ˜€λ₯Ό μ„€λͺ…ν•˜μ„Έμš” (μ΅œμ†Œ 50자)...",
            value="κΈ΄ κ°ˆμƒ‰ 머리와 밝은 ν”ΌλΆ€λ₯Ό 가진 여성이 κΈ΄ 금발 머리λ₯Ό 가진 λ‹€λ₯Έ 여성을 ν–₯ν•΄ λ―Έμ†Œ μ§“μŠ΅λ‹ˆλ‹€. κ°ˆμƒ‰ 머리 여성은 검은 μž¬ν‚·μ„ μž…κ³  있으며 였λ₯Έμͺ½ 뺨에 μž‘κ³  거의 λˆˆμ— 띄지 μ•ŠλŠ” 점이 μžˆμŠ΅λ‹ˆλ‹€. 카메라 액글은 κ°ˆμƒ‰ 머리 μ—¬μ„±μ˜ 얼꡴에 μ΄ˆμ μ„ 맞좘 ν΄λ‘œμ¦ˆμ—…μž…λ‹ˆλ‹€. μ‘°λͺ…은 λ”°λœ»ν•˜κ³  μžμ—°μŠ€λŸ¬μš°λ©°, μ•„λ§ˆλ„ μ§€λŠ” ν•΄μ—μ„œ λ‚˜μ˜€λŠ” 것 κ°™μ•„ μž₯면에 λΆ€λ“œλŸ¬μš΄ 빛을 λΉ„μΆ₯λ‹ˆλ‹€.",
            lines=5,
        )

        txt2vid_enhance_toggle = Toggle(
            label="Enhance Prompt",
            value=False,
            interactive=True,
        )

        # Step 2: negative prompt.
        txt2vid_negative_prompt = gr.Textbox(
            label="Step 2: Enter Negative Prompt",
            placeholder="λΉ„λ””μ˜€μ—μ„œ μ›ν•˜μ§€ μ•ŠλŠ” μš”μ†Œλ₯Ό μ„€λͺ…ν•˜μ„Έμš”...",
            value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
            lines=2,
        )

        txt2vid_preset = gr.Dropdown(
            choices=[p["label"] for p in preset_options],
            value="512x512, 160 frames",
            label="Step 3.1: Choose Resolution Preset",
        )

        txt2vid_frame_rate = gr.Slider(
            label="Step 3.2: Frame Rate",
            minimum=6,
            maximum=60,
            step=1,
            value=20,
        )

        txt2vid_advanced = create_advanced_options()
        txt2vid_generate = gr.Button(
            "Step 5: Generate Video",
            variant="primary",
            size="lg",
        )

        txt2vid_output = gr.Video(label="Generated Output")

    # The preset drives the hidden height/width/num_frames sliders.
    txt2vid_preset.change(
        fn=preset_changed,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )
    # .change does not fire for the dropdown's initial value. Push the default
    # preset into the sliders on page load; otherwise the first generation
    # uses the sliders' own defaults instead of the preset shown.
    iface.load(
        fn=preset_changed,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )

    txt2vid_generate.click(
        fn=generate_video_from_text_90,
        inputs=[
            txt2vid_prompt,
            txt2vid_enhance_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )
# Enable the request queue (at most 64 waiting, one job at a time, API closed)
# and start the server. Blocks.queue() returns the same Blocks object.
iface.queue(max_size=64, default_concurrency_limit=1, api_open=False)
iface.launch(share=True, show_api=False)
# ===== Application Startup at 2024-12-20 01:30:34 =====