# ofai-it2v2 / app.py
# Author: fantaxy. Last commit 5355a5a ("Update app.py", verified); 5.92 kB.
import gradio as gr
from gradio_client import Client, handle_file
import os
import logging
import json
from datetime import datetime
import tempfile
import numpy as np
from PIL import Image
# Logging setup
logging.basicConfig(level=logging.INFO)

# API client setup. The endpoint can be overridden with the API_URL
# environment variable; by default it points at the original server.
API_URL = os.environ.get("API_URL", "http://211.233.58.202:7960/")
api_client = Client(API_URL)

# Gallery storage: copied videos live in GALLERY_DIR and are indexed in GALLERY_JSON.
GALLERY_DIR = "gallery"
GALLERY_JSON = "gallery.json"

# Create the gallery directory (no-op if it already exists).
os.makedirs(GALLERY_DIR, exist_ok=True)
def save_to_gallery(video_path, prompt):
    """Copy a generated video into the gallery and record it in the JSON index.

    Args:
        video_path: Path of the video file to copy.
        prompt: Prompt text that produced the video; stored with the entry.

    Returns:
        Path of the copied video inside ``GALLERY_DIR``.
    """
    import shutil

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The timestamp only has one-second resolution; append a counter so two
    # videos saved within the same second do not overwrite each other.
    new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}.mp4")
    suffix = 1
    while os.path.exists(new_video_path):
        new_video_path = os.path.join(GALLERY_DIR, f"{timestamp}_{suffix}.mp4")
        suffix += 1

    # Copy the video file in chunks rather than loading it fully into memory.
    shutil.copyfile(video_path, new_video_path)

    # Record the gallery entry.
    gallery_info = {
        "video": new_video_path,
        "prompt": prompt,
        "timestamp": timestamp,
    }
    if os.path.exists(GALLERY_JSON):
        with open(GALLERY_JSON, "r", encoding="utf-8") as f:
            gallery = json.load(f)
    else:
        gallery = []
    gallery.append(gallery_info)

    # Write to a temporary file and atomically swap it in, so an interrupted
    # write cannot leave a truncated index behind.
    tmp_path = GALLERY_JSON + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(gallery, f, indent=2)
    os.replace(tmp_path, GALLERY_JSON)
    return new_video_path
def load_gallery():
    """Read the gallery index and return its entries, last-stored first.

    Returns:
        A list of ``(video_path, prompt)`` tuples in reverse of the order they
        appear in ``GALLERY_JSON``, or an empty list when the index file does
        not exist yet.
    """
    if not os.path.exists(GALLERY_JSON):
        return []
    with open(GALLERY_JSON) as index_file:
        records = json.load(index_file)
    newest_first = reversed(records)
    return [(record["video"], record["prompt"]) for record in newest_first]
def respond(image, prompt, steps, cfg_scale, eta, fs, seed, video_length):
    """Generate a video from an image and a prompt via the remote API.

    The image is written to a temporary PNG and uploaded to the ``/infer``
    endpoint of ``api_client`` with the generation settings. The returned
    ``.mp4`` is then copied into the gallery.

    Args:
        image: Uploaded image. A numpy array is encoded as PNG; any other value
            is written to the temp file as-is (assumed to be encoded bytes).
        prompt: Text prompt for the video.
        steps, cfg_scale, eta, fs, seed, video_length: Generation settings,
            forwarded unchanged to the API.

    Returns:
        Path of the saved video inside the gallery directory.

    Raises:
        gr.Error: If no image was provided or generation fails. Gradio shows
            the message to the user instead of treating it as a video path.
    """
    logging.info(f"Received prompt: {prompt}, steps: {steps}, cfg_scale: {cfg_scale}, "
                 f"eta: {eta}, fs: {fs}, seed: {seed}, video_length: {video_length}")
    if image is None:
        raise gr.Error("Please upload an image first.")

    temp_file_path = None
    try:
        # Save the image to a temporary file so it can be uploaded.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_file_path = temp_file.name
            if isinstance(image, np.ndarray):
                # Let Pillow infer the mode from the array shape; forcing "RGB"
                # misreads arrays that do not have exactly three channels.
                Image.fromarray(image.astype("uint8")).save(temp_file, format="PNG")
            else:
                temp_file.write(image)

        # Request video generation. handle_file() marks the path as a file to
        # upload rather than a plain string argument.
        result = api_client.predict(
            image=handle_file(temp_file_path),
            prompt=prompt,
            steps=steps,
            cfg_scale=cfg_scale,
            eta=eta,
            fs=fs,
            seed=seed,
            video_length=video_length,
            api_name="/infer",
        )
        logging.info("API response received: %s", result)

        # Check and handle the result. Video outputs may arrive as a dict such
        # as {"video": path, ...} instead of a bare path string.
        if isinstance(result, dict):
            result = result.get("video")
        if isinstance(result, str) and result.endswith(".mp4"):
            return save_to_gallery(result, prompt)
        raise ValueError("Unexpected API response format")
    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception("Error during API request: %s", e)
        raise gr.Error("Failed to generate video due to an error.") from e
    finally:
        # Always delete the temporary image, including when the request fails.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
# Custom CSS for the Gradio app: hides the page footer.
css = """
footer {
visibility: hidden;
}
"""
# Example prompts, each paired with a preview image file, shown as cards on
# the "Generate" tab. The text is kept plain ASCII (e.g. "She's") so it cannot
# be garbled by file-encoding mix-ups before it is shown or sent to the API.
examples = [
    [
        "A glamorous young woman with long, wavy blonde hair and smokey eye makeup, posing in a luxury hotel room. She's wearing a sparkly gold cocktail dress and holding up a white card with 'openfree.ai' written on it in elegant calligraphy. Soft, warm lighting creates a luxurious atmosphere. ",
        "q1.webp",
    ],
    ["A fantasy map of a fictional world, with detailed terrain and cities.", "q19.webp"],
]
def use_prompt(prompt):
    """Hand the given prompt straight back so it can be routed into a textbox."""
    selected_prompt = prompt
    return selected_prompt
# Build the UI: a "Generate" tab (inputs, settings, example prompts) and a
# "Gallery" tab listing saved videos.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Tab("Generate"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image")
            input_text = gr.Textbox(label="Enter your prompt for video generation")
            output_video = gr.Video(label="Generated Video")
        # Generation settings, passed unchanged to respond().
        with gr.Row():
            steps = gr.Slider(minimum=1, maximum=100, step=1, label="Steps", value=30)
            cfg_scale = gr.Slider(minimum=1, maximum=15, step=0.1, label="CFG Scale", value=3.5)
            eta = gr.Slider(minimum=0, maximum=1, step=0.1, label="ETA", value=1)
            fs = gr.Slider(minimum=1, maximum=30, step=1, label="FPS", value=8)
            seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="Seed", value=123)
            video_length = gr.Slider(minimum=1, maximum=10, step=1, label="Video Length (seconds)", value=2)
        # Example cards; "Use this prompt" copies the card's prompt into the textbox.
        with gr.Row():
            for prompt, image_file in examples:
                with gr.Column():
                    gr.Image(image_file, label=prompt[:50] + "...")  # show only the first 50 characters of the prompt
                    # The event has no inputs, so the prompt must be supplied
                    # explicitly; binding it as a default argument also gives
                    # each button its own prompt instead of the loop's last one.
                    gr.Button("Use this prompt").click(
                        fn=lambda p=prompt: use_prompt(p),
                        inputs=None,
                        outputs=input_text,
                        api_name=False,
                    )
    with gr.Tab("Gallery"):
        gallery = gr.Gallery(
            label="Generated Videos",
            show_label=False,
            elem_id="gallery",
            columns=[5],
            rows=[3],
            object_fit="contain",
            height="auto",
        )
        refresh_btn = gr.Button("Refresh Gallery")

    def update_gallery():
        """Return the current gallery entries for the Gallery component."""
        return load_gallery()

    refresh_btn.click(fn=update_gallery, inputs=None, outputs=gallery)
    # Fill the gallery when the page loads.
    demo.load(fn=update_gallery, inputs=None, outputs=gallery)
    # Pressing Enter in the prompt box generates a video, then refreshes the gallery.
    input_text.submit(
        fn=respond,
        inputs=[input_image, input_text, steps, cfg_scale, eta, fs, seed, video_length],
        outputs=output_video,
    ).then(
        fn=update_gallery,
        inputs=None,
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch()