# Hugging Face Space status: Running on Zero
import spaces | |
from functools import lru_cache | |
import gradio as gr | |
from gradio_toggle import Toggle | |
import torch | |
from huggingface_hub import snapshot_download | |
from transformers import CLIPProcessor, CLIPModel, pipeline | |
import random | |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder | |
from xora.models.transformers.transformer3d import Transformer3DModel | |
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier | |
from xora.schedulers.rf import RectifiedFlowScheduler | |
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline | |
from transformers import T5EncoderModel, T5Tokenizer | |
from xora.utils.conditioning_method import ConditioningMethod | |
from pathlib import Path | |
import safetensors.torch | |
import json | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import tempfile | |
import os | |
import gc | |
import csv | |
from datetime import datetime | |
from openai import OpenAI | |
# Initialise the Korean -> English translator.
# NOTE: `pipeline` here is the transformers factory; the name is rebound to the
# video pipeline object later in this file, so the translator has to be built first.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Disable TF32 and reduced-precision reductions and request the highest
# float32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cuda.preferred_blas_library = "cublas"
torch.set_float32_matmul_precision("highest")

# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Load Hugging Face token if needed
hf_token = os.getenv("HF_TOKEN")
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=openai_api_key)

# System prompt sent to the chat model when prompt enhancement is enabled.
system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
# Explicit encoding so the result does not depend on the host locale.
with open(system_prompt_t2v_path, "r", encoding="utf-8") as f:
    system_prompt_t2v = f.read()

# Set model download directory within Hugging Face Spaces
model_path = "asset"
commit_hash = "c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc"
# The snapshot is only fetched when the directory does not exist yet.
if not os.path.exists(model_path):
    snapshot_download(
        "Lightricks/LTX-Video",
        revision=commit_hash,
        local_dir=model_path,
        repo_type="model",
        token=hf_token,
    )

# Global variables to load components
vae_dir = Path(model_path) / "vae"
unet_dir = Path(model_path) / "unet"
scheduler_dir = Path(model_path) / "scheduler"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The CLIP model is placed on `device` so that it matches compute_clip_embedding,
# which moves the processor output to `device` before calling the model.
clip_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32", cache_dir=model_path
).to(device)
clip_processor = CLIPProcessor.from_pretrained(
    "openai/clip-vit-base-patch32", cache_dir=model_path
)
def process_prompt(prompt):
    """Return *prompt* in English, translating it when it contains Korean.

    Args:
        prompt: The user-supplied text. Empty or ``None`` input is returned
            unchanged.

    Returns:
        The output of the module-level ``translator`` when at least one
        Hangul syllable is present, otherwise *prompt* itself.
    """
    if not prompt:
        return prompt
    # Hangul syllables occupy U+AC00..U+D7A3. Escape sequences keep the range
    # check correct even if this file is re-encoded; the previous literal
    # characters were corrupted into multi-character strings, which makes
    # ord() raise TypeError.
    if any("\uac00" <= char <= "\ud7a3" for char in prompt):
        # Translate Korean to English.
        return translator(prompt)[0]["translation_text"]
    return prompt
def compute_clip_embedding(text=None):
    """Return the CLIP text features for *text* as a flat Python list.

    The text is run through the module-level ``clip_processor``, the result
    is moved to ``device``, and ``clip_model.get_text_features`` produces the
    features, which are detached, copied to the CPU and flattened.
    """
    encoded = clip_processor(text=text, return_tensors="pt", padding=True)
    encoded = encoded.to(device)
    features = clip_model.get_text_features(**encoded)
    return features.detach().cpu().numpy().flatten().tolist()
def load_vae(vae_dir):
    """Build the video autoencoder from the files under *vae_dir*.

    Reads ``config.json`` for the model configuration and
    ``vae_diffusion_pytorch_model.safetensors`` for the weights, then returns
    the model moved to ``device`` and converted to bfloat16.
    """
    weights_file = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    config_file = vae_dir / "config.json"
    with open(config_file, "r") as handle:
        config = json.load(handle)
    model = CausalVideoAutoencoder.from_config(config)
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    model = model.to(device)
    return model.to(torch.bfloat16)
def load_unet(unet_dir):
    """Build the 3D transformer from the files under *unet_dir*.

    Reads ``config.json`` for the model configuration and
    ``unet_diffusion_pytorch_model.safetensors`` for the weights (loaded with
    ``strict=True``), then returns the model moved to ``device`` and
    converted to bfloat16.
    """
    weights_file = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    config = Transformer3DModel.load_config(unet_dir / "config.json")
    model = Transformer3DModel.from_config(config)
    weights = safetensors.torch.load_file(weights_file)
    model.load_state_dict(weights, strict=True)
    model = model.to(device)
    return model.to(torch.bfloat16)
def load_scheduler(scheduler_dir):
    """Create the rectified-flow scheduler from ``scheduler_config.json`` in *scheduler_dir*."""
    config = RectifiedFlowScheduler.load_config(scheduler_dir / "scheduler_config.json")
    return RectifiedFlowScheduler.from_config(config)
# Preset options for resolution and frame configuration.
# Every label spells out exactly the width, height and frame count the preset
# applies. Two entries used to be labelled "720x720" while carrying 768x768
# values; one of them duplicated the "768x768, 64 frames" preset and was
# dropped, the other is now labelled with its real size.
preset_options = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
    {"label": "768x768, 100 frames", "width": 768, "height": 768, "num_frames": 100},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
    {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
def preset_changed(preset):
    """Map a preset label to slider values and visibility updates.

    Returns a 6-tuple ``(height, width, num_frames, update, update, update)``.
    For the label ``"Custom"`` the three values are ``None`` and the updates
    make their targets visible; for any other label the values come from the
    matching ``preset_options`` entry and the updates hide their targets.
    A label with no matching entry makes ``next`` raise ``StopIteration``.
    """
    if preset == "Custom":
        reveal = tuple(gr.update(visible=True) for _ in range(3))
        return (None, None, None) + reveal

    chosen = next(entry for entry in preset_options if entry["label"] == preset)
    values = (chosen["height"], chosen["width"], chosen["num_frames"])
    conceal = tuple(gr.update(visible=False) for _ in range(3))
    return values + conceal
# Load models
_T5_REPO = "PixArt-alpha/PixArt-XL-2-1024-MS"
_CUDA0 = torch.device("cuda:0")

vae = load_vae(vae_dir)
unet = load_unet(unet_dir)
scheduler = load_scheduler(scheduler_dir)
patchifier = SymmetricPatchifier(patch_size=1)

# Text encoder and tokenizer come from sub-folders of the same repository.
text_encoder = T5EncoderModel.from_pretrained(_T5_REPO, subfolder="text_encoder")
text_encoder = text_encoder.to(_CUDA0)
tokenizer = T5Tokenizer.from_pretrained(_T5_REPO, subfolder="tokenizer")

# NOTE: this rebinds the module-level name `pipeline`, which until this point
# referred to the factory imported from transformers.
pipeline = XoraVideoPipeline(
    transformer=unet,
    patchifier=patchifier,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    vae=vae,
)
pipeline = pipeline.to(_CUDA0)
def enhance_prompt_if_enabled(prompt, enhance_toggle):
    """Optionally rewrite *prompt* with the OpenAI chat model.

    Args:
        prompt: The text to enhance.
        enhance_toggle: When falsy, *prompt* is returned untouched and no
            request is made.

    Returns:
        The stripped model reply, or the original *prompt* when enhancement
        is disabled, the request fails, or the reply is empty. Enhancement is
        best-effort: request errors are printed, never raised.
    """
    if not enhance_toggle:
        print("Enhance toggle is off, Prompt: ", prompt)
        return prompt

    messages = [
        {"role": "system", "content": system_prompt_t2v},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            # "gpt-4-mini" is not an OpenAI model id, so every request failed
            # and the except branch silently returned the raw prompt.
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=200,
        )
        enhanced = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"Error: {e}")
        return prompt

    if not enhanced:
        # An empty reply would otherwise replace the prompt with nothing.
        return prompt
    print("Enhanced Prompt: ", enhanced)
    return enhanced
def _write_video_file(frames, frame_rate):
    """Encode *frames* into a new temporary ``.mp4`` file and return its path.

    *frames* is indexed as (frame, height, width, channel); the channel axis
    is reversed before writing (presumably RGB -> BGR for OpenCV).
    """
    frame_height, frame_width = frames.shape[1:3]
    # NamedTemporaryFile creates the file atomically; tempfile.mktemp is
    # deprecated because another process can claim the name first.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        frame_rate,
        (frame_width, frame_height),
    )
    try:
        for frame in frames[..., ::-1]:
            writer.write(frame)
    finally:
        # Always release so the container is finalised even if a write fails.
        writer.release()
    return output_path


# NOTE(review): the default `seed` is evaluated once, when the function is
# defined, so callers relying on the default share one seed per process.
# NOTE(review): this file imports `spaces` but nothing visible applies
# `@spaces.GPU` here -- confirm whether the deployment needs it.
def generate_video_from_text_90(
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=random.randint(0, MAX_SEED),
    num_inference_steps=30,
    guidance_scale=3.2,
    height=768,
    width=768,
    num_frames=60,
    progress=gr.Progress(),
):
    """Generate a video from a text prompt and return the path of the MP4 file.

    Args:
        prompt: Description of the video; translated to English when it
            contains Korean. Must be at least 50 characters after stripping.
        enhance_prompt_toggle: Whether to pass the prompt through
            ``enhance_prompt_if_enabled``.
        negative_prompt: Text describing what to avoid; also translated.
        frame_rate: Passed to the pipeline and used as the output file's FPS.
        seed: Seed for the CUDA random generator.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        height: Requested frame height.
        width: Requested frame width.
        num_frames: Requested number of frames.
        progress: Gradio progress tracker updated after every step.

    Returns:
        Path of a temporary ``.mp4`` file containing the generated video.

    Raises:
        gr.Error: If the prompt is too short or the pipeline call fails.
    """
    # Prompt preprocessing (Korean -> English).
    prompt = process_prompt(prompt)
    negative_prompt = process_prompt(negative_prompt)

    if len(prompt.strip()) < 50:
        raise gr.Error(
            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
            duration=5,
        )

    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)

    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": None,
    }

    # UI sliders can deliver floats; the generator seed and the size/step
    # arguments have to be integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)
    num_frames = int(num_frames)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    def gradio_progress_callback(self, step, timestep, kwargs):
        progress((step + 1) / num_inference_steps)

    try:
        with torch.no_grad():
            images = pipeline(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=gradio_progress_callback,
            ).images
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise gr.Error(
            f"An error occurred while generating the video. Please try again. Error: {e}",
            duration=5,
        ) from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()

    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    # Assumes frames are scaled to [0, 1] -- TODO confirm. Clipping keeps any
    # out-of-range value from wrapping around in the uint8 cast.
    video_np = (np.clip(video_np, 0.0, 1.0) * 255).astype(np.uint8)

    output_path = _write_video_file(video_np, frame_rate)

    # Drop the large buffers before handing control back.
    del images
    del video_np
    torch.cuda.empty_cache()
    return output_path
def create_advanced_options():
    """Build the "Advanced Options" accordion and return its controls.

    Returns:
        A list ``[seed, inference_steps, guidance_scale, height_slider,
        width_slider, num_frames_slider]``. The last three sliders are
        created with ``visible=False``.
    """
    with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
        inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
        guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
        height_slider = gr.Slider(
            label="4.4 Height",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        width_slider = gr.Slider(
            label="4.5 Width",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        num_frames_slider = gr.Slider(
            # Was "4.5", which duplicated the width slider's number.
            label="4.6 Number of Frames",
            minimum=1,
            maximum=500,
            step=1,
            value=60,
            visible=False,
        )
    return [
        seed,
        inference_steps,
        guidance_scale,
        height_slider,
        width_slider,
        num_frames_slider,
    ]
def _preset_slider_values(preset):
    """Return only the (height, width, num_frames) part of ``preset_changed``.

    ``preset_changed`` returns six values, but the events below are wired to
    just three slider outputs, so the trailing visibility updates are dropped.
    """
    return preset_changed(preset)[:3]


# NOTE(review): the Korean UI strings below were reconstructed from
# mis-encoded text -- verify them against a canonical copy of the file.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Column():
        txt2vid_prompt = gr.Textbox(
            label="Step 1: Enter Your Prompt (한글 또는 영어)",
            placeholder="생성하고 싶은 비디오를 설명하세요 (최소 50자)...",
            value="긴 갈색 머리와 밝은 피부를 가진 여성이 긴 금발 머리를 가진 다른 여성을 향해 미소 짓습니다. 갈색 머리 여성은 검은 재킷을 입고 있으며 오른쪽 뺨에 작고 거의 눈에 띄지 않는 점이 있습니다. 카메라 앵글은 갈색 머리 여성의 얼굴에 초점을 맞춘 클로즈업입니다. 조명은 따뜻하고 자연스러우며, 아마도 지는 해에서 나오는 것 같아 장면에 부드러운 빛을 비춥니다.",
            lines=5,
        )
        txt2vid_enhance_toggle = Toggle(
            label="Enhance Prompt",
            value=False,
            interactive=True,
        )
        txt2vid_negative_prompt = gr.Textbox(
            label="Step 2: Enter Negative Prompt",
            placeholder="비디오에서 원하지 않는 요소를 설명하세요...",
            value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
            lines=2,
        )
        txt2vid_preset = gr.Dropdown(
            choices=[p["label"] for p in preset_options],
            value="512x512, 160 frames",
            label="Step 3.1: Choose Resolution Preset",
        )
        txt2vid_frame_rate = gr.Slider(
            label="Step 3.2: Frame Rate",
            minimum=6,
            maximum=60,
            step=1,
            value=20,
        )
        txt2vid_advanced = create_advanced_options()
        txt2vid_generate = gr.Button(
            "Step 5: Generate Video",
            variant="primary",
            size="lg",
        )
        txt2vid_output = gr.Video(label="Generated Output")

    # Selecting a preset writes its size and frame count into the hidden sliders.
    txt2vid_preset.change(
        fn=_preset_slider_values,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )
    # The sliders are created with their own defaults, so apply the initially
    # selected preset on page load; otherwise the first generation would not
    # use the size shown in the dropdown.
    iface.load(
        fn=_preset_slider_values,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )
    txt2vid_generate.click(
        fn=generate_video_from_text_90,
        inputs=[
            txt2vid_prompt,
            txt2vid_enhance_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )

iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
# ===== Application Startup at 2024-12-20 01:30:34 ===== |