Spaces:
Runtime error
Runtime error
import gradio as gr | |
from gradio_client import Client, handle_file | |
import os | |
import logging | |
import json | |
from datetime import datetime | |
import tempfile | |
import numpy as np | |
from PIL import Image | |
# Logging setup.
logging.basicConfig(level=logging.INFO)

# API client setup: connection to the remote video-generation app.
# Constructing a gradio_client Client connects to the server immediately, so
# an unreachable backend would otherwise abort the whole app at import time.
# Keep the UI up instead; requests then fail cleanly because the client is None.
try:
    api_client = Client("http://211.233.58.202:7960/")
except Exception:
    logging.exception("Could not connect to the video-generation API")
    api_client = None

# Gallery storage: copied videos live in GALLERY_DIR, and their metadata
# (video path, prompt, timestamp) is indexed in GALLERY_JSON.
GALLERY_DIR = "gallery"
GALLERY_JSON = "gallery.json"

# Create the gallery directory if it does not exist yet.
os.makedirs(GALLERY_DIR, exist_ok=True)
def save_to_gallery(video_path, prompt):
    """Copy a generated video into the gallery and append it to the JSON index.

    Args:
        video_path: Path of the video file to store; the file is copied, not moved.
        prompt: Prompt text recorded alongside the video.

    Returns:
        Path of the copy inside ``GALLERY_DIR``.
    """
    import shutil

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # The timestamp has one-second resolution; when two videos are saved in the
    # same second, add a numeric suffix instead of overwriting the first file.
    new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}.mp4")
    counter = 1
    while os.path.exists(new_video_path):
        new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}_{counter}.mp4")
        counter += 1

    # Copy in chunks rather than loading the whole video into memory.
    shutil.copyfile(video_path, new_video_path)

    gallery_info = {
        "video": new_video_path,
        "prompt": prompt,
        "timestamp": timestamp,
    }

    if os.path.exists(GALLERY_JSON):
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    else:
        gallery = []
    gallery.append(gallery_info)

    # Write the index to a temporary file next to it and atomically swap it in,
    # so an interrupted write cannot leave a truncated, unparseable index.
    index_dir = os.path.dirname(os.path.abspath(GALLERY_JSON))
    fd, tmp_path = tempfile.mkstemp(dir=index_dir, suffix=".json.tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(gallery, f, indent=2)
        os.replace(tmp_path, GALLERY_JSON)
    except BaseException:
        os.unlink(tmp_path)
        raise

    return new_video_path
def load_gallery():
    """Return gallery entries as ``(video_path, prompt)`` tuples, newest first.

    Reads the JSON index that ``save_to_gallery`` appends to. Returns an empty
    list when the index is missing or cannot be read or parsed, and skips
    entries whose video file no longer exists. A damaged index or a deleted
    video therefore does not break the gallery refresh.
    """
    try:
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.exception("Could not read gallery index %s", GALLERY_JSON)
        return []
    # The index is appended in chronological order; reverse for newest first.
    return [
        (item["video"], item["prompt"])
        for item in reversed(gallery)
        if os.path.exists(item["video"])
    ]
def _write_temp_image(image):
    """Save ``image`` to a temporary ``.png`` file and return that file's path.

    Accepts a numpy array, a PIL image, or raw encoded image bytes (written
    as-is). The caller is responsible for deleting the returned file.
    """
    if image is None:
        raise ValueError("No input image was provided")
    fd, path = tempfile.mkstemp(suffix=".png")
    try:
        with os.fdopen(fd, "wb") as temp_file:
            if isinstance(image, np.ndarray):
                # Let Pillow infer the mode from the array shape (RGB, RGBA or
                # grayscale) instead of forcing "RGB" on every array.
                Image.fromarray(image.astype("uint8")).save(temp_file, format="PNG")
            elif isinstance(image, Image.Image):
                image.save(temp_file, format="PNG")
            else:
                temp_file.write(image)
    except BaseException:
        os.unlink(path)
        raise
    return path


def _video_path_from_result(result):
    """Return the ``.mp4`` path contained in an API result, or None if absent.

    Accepts a plain path string, or a dict with a ``"video"`` key.
    NOTE(review): the dict form is an assumption about how gradio_client can
    return video outputs; confirm against the remote app.
    """
    if isinstance(result, dict):
        result = result.get("video")
    if isinstance(result, str) and result.endswith(".mp4"):
        return result
    return None


def respond(image, prompt, steps, cfg_scale, eta, fs, seed, video_length):
    """Generate a video from an image and a prompt via the remote ``/infer`` API.

    The image is written to a temporary PNG and uploaded together with the
    generation settings. The returned video is then copied into the gallery.

    Returns:
        Path of the gallery copy of the generated video.

    Raises:
        gr.Error: if the request fails or the API result has an unexpected
            format. Gradio shows the message in the UI.
    """
    logging.info(
        "Received prompt: %s, steps: %s, cfg_scale: %s, eta: %s, fs: %s, "
        "seed: %s, video_length: %s",
        prompt, steps, cfg_scale, eta, fs, seed, video_length,
    )
    temp_file_path = None
    try:
        if api_client is None:
            raise RuntimeError("Video generation API client is not available")
        temp_file_path = _write_temp_image(image)
        # Request video generation. With gradio_client >= 1.0, a file input
        # must be wrapped in handle_file() to be uploaded rather than sent as
        # a plain string.
        result = api_client.predict(
            image=handle_file(temp_file_path),
            prompt=prompt,
            steps=steps,
            cfg_scale=cfg_scale,
            eta=eta,
            fs=fs,
            seed=seed,
            video_length=video_length,
            api_name="/infer",
        )
        logging.info("API response received: %s", result)
        video_path = _video_path_from_result(result)
        if video_path is None:
            raise ValueError("Unexpected API response format")
        return save_to_gallery(video_path, prompt)
    except Exception as e:
        logging.exception("Error during API request: %s", e)
        # Raise instead of returning a message: the output component is a
        # gr.Video, which would treat a returned string as a file path.
        raise gr.Error("Failed to generate video due to an error.") from e
    finally:
        # Remove the temporary upload on every path, including failures.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                logging.warning("Could not delete temp file %s", temp_file_path)
# Custom CSS passed to gr.Blocks: hides the page footer.
css = """
footer {
    visibility: hidden;
}
"""

# Example prompts for video generation, each paired with a preview image file
# that is shown on the "Generate" tab.
examples = [
    [
        "A glamorous young woman with long, wavy blonde hair and smokey eye makeup, posing in a luxury hotel room. She’s wearing a sparkly gold cocktail dress and holding up a white card with 'openfree.ai' written on it in elegant calligraphy. Soft, warm lighting creates a luxurious atmosphere. ",
        "q1.webp",
    ],
    ["A fantasy map of a fictional world, with detailed terrain and cities.", "q19.webp"],
]
def use_prompt(prompt):
    """Hand back the given prompt untouched, e.g. to copy it into a textbox."""
    chosen = prompt
    return chosen
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Tab("Generate"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image")
            input_text = gr.Textbox(label="Enter your prompt for video generation")
            output_video = gr.Video(label="Generated Video")
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=30)
            cfg_scale = gr.Slider(minimum=1, maximum=15, step=0.1, label="CFG Scale", value=3.5)
            eta = gr.Slider(minimum=0, maximum=1, step=0.1, label="ETA", value=1)
            fs = gr.Slider(minimum=1, maximum=30, step=1, label="FPS", value=8)
            seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
            video_length = gr.Slider(minimum=1, maximum=10, step=1, label="Video Length (seconds)", value=2)
        with gr.Row():
            for example_prompt, image_file in examples:
                with gr.Column():
                    # Use only the first 50 characters of the prompt as the label,
                    # adding an ellipsis only when the prompt was actually cut.
                    label = example_prompt[:50] + ("..." if len(example_prompt) > 50 else "")
                    gr.Image(image_file, label=label)
                    # The handler receives no inputs, so it must be callable
                    # with zero arguments. The prompt is bound as a default
                    # argument so each button keeps its own example (the loop
                    # variable would otherwise be late-bound).
                    gr.Button("Use this prompt").click(
                        fn=lambda text=example_prompt: text,
                        inputs=None,
                        outputs=input_text,
                        api_name=False,
                    )
    with gr.Tab("Gallery"):
        gallery = gr.Gallery(
            label="Generated Videos",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[3],
            object_fit="contain",
            height="auto",
        )
        refresh_btn = gr.Button("Refresh Gallery")

    def update_gallery():
        """Reload gallery entries from the JSON index via load_gallery()."""
        return load_gallery()

    refresh_btn.click(fn=update_gallery, inputs=None, outputs=gallery)
    demo.load(fn=update_gallery, inputs=None, outputs=gallery)
    # Generate when Enter is pressed in the prompt box, then refresh the
    # gallery so the newly saved video appears.
    input_text.submit(
        fn=respond,
        inputs=[input_image, input_text, steps, cfg_scale, eta, fs, seed, video_length],
        outputs=output_video,
    ).then(
        fn=update_gallery,
        inputs=None,
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch()