# Hugging Face Space: DimensionX demo (running on an NVIDIA L40S GPU).
# File size as shown on the Space's file view: 6,532 bytes.
# (The per-line commit hashes and line-number gutter from that web view are omitted.)
import gradio as gr
import os
import torch
import gc
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import T5EncoderModel, T5Tokenizer
from datetime import datetime
import random
from huggingface_hub import hf_hub_download
# Configure the CUDA caching allocator before anything can touch the GPU:
# PyTorch reads this setting when CUDA is first initialised, so setting it
# later (e.g. after models are moved to the GPU) would have no effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Ensure the 'checkpoints' directory exists and download both DimensionX
# orbit LoRA weight files into it.
os.makedirs("checkpoints", exist_ok=True)
for lora_filename in (
    "orbit_left_lora_weights.safetensors",
    "orbit_up_lora_weights.safetensors",
):
    hf_hub_download(
        repo_id="wenqsun/DimensionX",
        filename=lora_filename,
        local_dir="checkpoints",
    )

# Load models once, in the global scope, and keep them on CPU until a
# request needs the GPU.
# NOTE(review): the CogVideoX model cards recommend bfloat16 for the 5B
# variants; float16 is kept here unchanged. Confirm output quality.
model_id = "THUDM/CogVideoX-5b-I2V"
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16).to("cpu")
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16).to("cpu")
vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16).to("cpu")
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    model_id,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.float16,
)
def find_and_move_object_to_cpu():
    """Best-effort sweep that moves every live CUDA-resident module to CPU.

    Walks all objects tracked by the garbage collector. Each
    ``torch.nn.Module`` with at least one parameter or buffer on a CUDA
    device gets exactly one ``.to("cpu")`` call.

    Returns:
        int: The number of modules that were moved.
    """
    moved = 0
    for obj in gc.get_objects():
        try:
            if not isinstance(obj, torch.nn.Module):
                continue
            on_cuda = any(p.is_cuda for p in obj.parameters()) or any(
                b.is_cuda for b in obj.buffers()
            )
            if on_cuda:
                # A single .to() moves parameters and buffers together.
                obj.to("cpu")
                moved += 1
        except Exception:
            # Deliberately best-effort: an arbitrary gc-tracked object may
            # raise while being inspected or moved; skip it and keep sweeping.
            continue
    return moved
def clear_gpu():
    """Free unreachable Python objects, then release cached CUDA memory.

    Garbage is collected first so tensors kept alive only by reference cycles
    are freed before the CUDA caching allocator is asked to return its unused
    blocks. On machines without CUDA only the garbage collection runs (the
    original order raised there at ``torch.cuda.synchronize()``).
    """
    gc.collect()
    if torch.cuda.is_available():
        # Finish queued GPU work before releasing the allocator's cached blocks.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
    """Generate an orbit-camera video from one image using a DimensionX LoRA.

    Steps:
      1. Load the LoRA matching ``orbit_type`` and fuse it into the global
         CogVideoX image-to-video pipeline while the pipeline is on CPU.
      2. Move the pipeline to the GPU and sample 50 steps.
      3. In all cases (success or failure), move the pipeline back to CPU,
         then unfuse and unload the LoRA. The next request therefore starts
         from the base weights.

    Args:
        image_path: Path of the input image (the Gradio image component uses
            ``type="filepath"``).
        prompt: Text prompt; a fixed quality suffix is appended to it.
        orbit_type: ``"Left"`` selects the left-orbit LoRA; any other value
            selects the up-orbit LoRA.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

    Returns:
        str: Path of the exported 8 fps MP4 file, written to the working directory.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail fast with a user-facing message instead of an error inside load_image.
    if not image_path:
        raise gr.Error("Please upload an input image.")

    # Start from a known state: whole pipeline on CPU, CUDA cache released.
    pipe.to("cpu")
    torch.cuda.empty_cache()

    lora_path = "checkpoints/"
    weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
    lora_rank = 256
    adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    try:
        # Load and fuse the LoRA on CPU. Both steps sit inside the try so a
        # failure in either one still triggers the cleanup in ``finally``.
        pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}")
        pipe.fuse_lora(lora_scale=1 / lora_rank)

        # Move to GPU just before inference.
        pipe.to("cuda")
        torch.cuda.empty_cache()

        full_prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
        image = load_image(image_path)
        # The seed is drawn from only 256 possible values (0..255).
        seed = random.randint(0, 2**8 - 1)
        with torch.inference_mode():
            video = pipe(
                image,
                full_prompt,
                num_inference_steps=50,
                guidance_scale=7.0,
                use_dynamic_cfg=True,
                generator=torch.Generator(device="cpu").manual_seed(seed),
            )
    finally:
        # Always restore the base pipeline on CPU, even if anything above failed.
        pipe.to("cpu")
        pipe.unfuse_lora()
        pipe.unload_lora_weights()
        torch.cuda.empty_cache()
        gc.collect()

    # Build the output path once. Microsecond resolution keeps requests
    # finishing within the same second from overwriting each other's files.
    output_path = f"output_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp4"
    export_to_video(video.frames[0], output_path, fps=8)
    return output_path
# Set up Gradio UI: image + prompt + orbit direction in, generated video out.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# DimensionX")
        gr.Markdown("### Create Any 3D and 4D Scenes from a Single Image with Controllable Video Diffusion")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/wenqsun/DimensionX">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://chenshuo20.github.io/DimensionX/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/abs/2411.04928">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/DimensionX?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image Input", type="filepath")
                prompt = gr.Textbox(label="Prompt")
                orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_out = gr.Video(label="Video output")
                # Precomputed outputs are listed as an extra "input" so that
                # clicking an example also shows its rendered video.
                examples = gr.Examples(
                    examples=[
                        [
                            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
                            "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background.",
                            "Left",
                            "./examples/output_astronaut_left.mp4",
                        ],
                        [
                            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
                            "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background.",
                            "Up",
                            "./examples/output_astronaut_up.mp4",
                        ],
                    ],
                    inputs=[image_in, prompt, orbit_type, video_out],
                )
    submit_btn.click(
        fn=infer,
        inputs=[image_in, prompt, orbit_type],
        outputs=[video_out],
    )

demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)