# Flux-TRELLIS / app.py
# Source page metadata (captured from the Hugging Face Space file view):
#   author: gokaygokay, commit 521472f ("Update app.py", verified), size 10.5 kB.
# Keep `spaces` first: ZeroGPU requires it to be imported before torch/CUDA.
import spaces

import argparse
import os
import shutil
import warnings
from glob import glob

import gradio as gr
import numpy as np
import torch
from einops import rearrange
# Correct spelling: `snapshot_dowload` does not exist and raised ImportError.
from huggingface_hub import snapshot_download
from PIL import Image
# Suppress warnings: silence UserWarning, FutureWarning and DeprecationWarning
# for the whole process. Each simplefilter() call inserts its filter at the
# front of the list, so looping in this order reproduces three separate calls.
for _ignored_category in (UserWarning, FutureWarning, DeprecationWarning):
    warnings.simplefilter('ignore', category=_ignored_category)
del _ignored_category
def download_models():
    """Download the model weights this demo needs into ``./weights``.

    Two Hugging Face Hub repositories are fetched:

    * ``tencent/Hunyuan3D-1`` into ``./weights``
    * ``Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled`` into
      ``./weights/hunyuanDiT``

    Downloads are best-effort by design. A failure is printed, and the next
    repository is still attempted, so startup can continue with any weights
    that are already on disk. Returns None.
    """
    # (repo id, name used in log messages, target directory)
    models = (
        ("tencent/Hunyuan3D-1", "Hunyuan3D-1", "./weights"),
        ("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled", "HunyuanDiT",
         "./weights/hunyuanDiT"),
    )
    # Create every target directory up front, as before.
    for _, _, local_dir in models:
        os.makedirs(local_dir, exist_ok=True)

    for repo_id, label, local_dir in models:
        try:
            # `resume_download` is no longer passed: huggingface_hub deprecated
            # it (downloads resume automatically) and announced its removal.
            snapshot_download(repo_id=repo_id, local_dir=local_dir)
        except Exception as e:  # best-effort: log and move on to the next model
            print(f"Error downloading {label}: {e}")
        else:
            print(f"Successfully downloaded {label} model")
# Download models before starting the app. download_models() prints
# failures instead of raising, so startup continues either way.
download_models()

# Command-line options, registered in the order they appear in --help.
parser = argparse.ArgumentParser()
_add_option = parser.add_argument
_add_option("--use_lite", action="store_true", default=False)
_add_option("--mv23d_cfg_path", type=str, default="./svrm/configs/svrm.yaml")
_add_option("--mv23d_ckt_path", type=str, default="weights/svrm/svrm.safetensors")
_add_option("--text2image_path", type=str, default="weights/hunyuanDiT")
_add_option("--save_memory", action="store_true", default=False)
_add_option("--device", type=str, default="cuda:0")
args = parser.parse_args()
del _add_option

# Constants: server binding and queue size used by demo.queue()/demo.launch().
CONST_SERVER = '0.0.0.0'
CONST_PORT = 8080
CONST_MAX_QUEUE = 1
# HTML header rendered with gr.Markdown at the top of the page.
CONST_HEADER = '''
<h2><b>Official 🤗 Gradio Demo</b></h2>
<h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>
<b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
'''
# Helper functions
def get_example_img_list():
    """Return the bundled example image paths, sorted by name.

    Matches ``./demos/example_*.png`` relative to the current working
    directory. Returns an empty list when nothing matches.
    """
    print('Loading example img list ...')
    paths = glob('./demos/example_*.png')
    paths.sort()
    return paths
def get_example_txt_list(path='./demos/example_list.txt'):
    """Return the example text prompts, one per non-blank line of *path*.

    Args:
        path: Prompt file to read. The default is the bundled
            ``./demos/example_list.txt``, relative to the working directory.

    Returns:
        A list of prompts with surrounding whitespace stripped. Blank lines
        are skipped so that they do not become empty examples in the UI.

    Raises:
        OSError: if *path* cannot be opened.
    """
    print('Loading example txt list ...')
    # Read explicitly as UTF-8 because the prompts contain non-ASCII (Chinese)
    # text. `with` closes the handle, which the original loop leaked.
    with open(path, encoding='utf-8') as fh:
        return [line.strip() for line in fh if line.strip()]
# Collect the example inputs once at startup; example_ts feeds the
# "Txt examples" widget and example_is the "Img examples" widget in the UI.
example_is = get_example_img_list()
example_ts = get_example_txt_list()
# Import required workers
# Project-local pipeline components. save_gif is used by stage_2_i2v.
# NOTE(review): seed_everything appears unused in this file — confirm.
from infer import seed_everything, save_gif
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
# Initialize workers
# Each worker is constructed once at import time and shared by every request
# through the stage_* functions; construction order is kept as-is.
# Background remover, called by stage_0_t2i (text path) and stage_1_xbg.
worker_xbg = Removebg()
print(f"loading {args.text2image_path}")
# Text-to-image model loaded from --text2image_path; stage_0_t2i calls it as
# worker_t2i(text, seed, step).
worker_t2i = Text2Image(
pretrain=args.text2image_path,
device=args.device,
save_memory=args.save_memory
)
# Image-to-multi-view generator; stage_2_i2v calls it and also reads its
# `order` attribute to permute the generated view tiles.
worker_i2v = Image2Views(
use_lite=args.use_lite,
device=args.device
)
# Multi-view-to-mesh model built from --mv23d_cfg_path / --mv23d_ckt_path;
# used by stage_3_v23.
worker_v23 = Views2Mesh(
args.mv23d_cfg_path,
args.mv23d_ckt_path,
use_lite=args.use_lite,
device=args.device
)
# GIF renderer used by stage_4_gif.
worker_gif = GifRenderer(args.device)
# Pipeline stages
def _next_output_folder(root='./outputs/app_output', n_slots=30):
    """Allocate a fresh per-request output folder under *root* and return its path.

    Folders are named ``0`` .. ``n_slots - 1`` and recycled as a ring buffer.
    The lowest unused slot is taken, or slot ``0`` once every slot is in use.
    The slot after the chosen one is deleted, so the next request always
    finds a free slot.
    """
    os.makedirs(root, exist_ok=True)
    # Only decimal names are slots. Hidden files and stray entries are
    # skipped; previously int() raised ValueError on any non-numeric name.
    taken = {int(name) for name in os.listdir(root) if name.isdecimal()}
    free = set(range(n_slots)) - taken
    cur_id = min(free) if free else 0
    save_folder = f'{root}/{cur_id}'
    # A recycled slot still holds an earlier request's files. Clear it so that
    # stale artifacts cannot be mistaken for this request's output.
    shutil.rmtree(save_folder, ignore_errors=True)
    next_folder = f'{root}/{(cur_id + 1) % n_slots}'
    if os.path.exists(next_folder):
        shutil.rmtree(next_folder)
    os.makedirs(save_folder, exist_ok=True)
    return save_folder


@spaces.GPU
def stage_0_t2i(text, image, seed, step):
    """Stage 0: obtain the input image and allocate this request's output folder.

    Text path (``text`` is non-empty): generate an image with ``worker_t2i``,
    save it as ``<save_folder>/img.png`` and return whatever ``worker_xbg``
    produces for it. The UI displays that result as the background-removed
    image.

    Image path (``text`` is empty): save ``image`` as ``<save_folder>/img.png``
    and return that path.

    Args:
        text: Prompt for text-to-image, or an empty value/None for image input.
        image: Uploaded image (an object with ``.save(path)``), or None.
        seed: Text-to-image seed (text path only).
        step: Text-to-image step count (text path only).

    Returns:
        ``(image_or_path, save_folder)``.

    Raises:
        gr.Error: if neither a prompt nor an image was provided.
    """
    if not text and image is None:
        # Fail fast with a user-visible message. Otherwise a path to a file
        # that is never written would reach the next stage and fail opaquely.
        raise gr.Error("Please provide an input text or upload an image.")

    save_folder = _next_output_folder()
    dst = save_folder + '/img.png'

    if not text:
        image.save(dst)
        return dst, save_folder

    image = worker_t2i(text, seed, step)
    image.save(dst)
    # NOTE(review): worker_xbg receives save_folder as a second positional
    # argument here but only the image in stage_1_xbg; confirm this matches
    # Removebg's signature. Its result is passed straight to the UI.
    return worker_xbg(image, save_folder), save_folder
@spaces.GPU
def stage_1_xbg(image, save_folder):
    """Stage 1: remove the background from the input image.

    ``image`` may be an image object or a filesystem path, which is opened
    with PIL. The image returned by ``worker_xbg`` is written to
    ``<save_folder>/img_nobg.png``, and that path is returned.
    """
    source = Image.open(image) if isinstance(image, str) else image
    out_path = save_folder + '/img_nobg.png'
    worker_xbg(source).save(out_path)
    return out_path
@spaces.GPU
def stage_2_i2v(image, seed, step, save_folder):
    """Stage 2: generate multi-view images from the (background-free) input.

    Runs ``worker_i2v`` and saves the frames it returns as
    ``<save_folder>/views.gif``. The first returned image is a tile grid.
    Per the einops patterns below, it is 3 tiles tall and 2 tiles wide with a
    trailing channel axis. It is split into its 6 tiles, permuted by
    ``worker_i2v.order`` and re-tiled as a 2 x 3 grid for display.

    Returns:
        ``(views_img, cond_img, show_img)``: the views grid and the
        conditioning image from ``worker_i2v``, plus the re-tiled PIL preview.
    """
    if isinstance(image, str):
        image = Image.open(image)
    gif_path = save_folder + '/views.gif'

    results, frames = worker_i2v(image, seed, step)
    save_gif(frames, gif_path)
    views_img = results[0]
    cond_img = results[1]

    grid = np.asarray(views_img, dtype=np.uint8)
    tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
    tiles = tiles[worker_i2v.order, ...]
    retiled = rearrange(tiles, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
    return views_img, cond_img, Image.fromarray(retiled)
@spaces.GPU
def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000,
                do_texture_mapping=True, do_render=True):
    """Stage 3: reconstruct a mesh from the multi-view images via ``worker_v23``.

    Args:
        views_pil: Views image produced by stage_2_i2v.
        cond_pil: Conditioning image produced by stage_2_i2v.
        seed: Seed forwarded to ``worker_v23``.
        save_folder: Per-request output folder, forwarded to ``worker_v23``.
        target_face_count: Face budget forwarded to ``worker_v23``.
        do_texture_mapping: Whether to request texture mapping.
        do_render: When truthy, texture mapping is requested as well. Stage 4
            renders ``<save_folder>/mesh.obj``, which presumably depends on
            texture mapping — confirm against Views2Mesh.

    Returns:
        ``(obj_path, glb_path)``: the expected locations of
        ``mesh_with_colors.obj`` and ``mesh.glb`` in ``save_folder``. Their
        existence is not checked here.
    """
    expected_outputs = (
        save_folder + '/mesh_with_colors.obj',
        save_folder + '/mesh.glb',
    )
    worker_v23(
        views_pil,
        cond_pil,
        seed=seed,
        save_folder=save_folder,
        target_face_count=target_face_count,
        do_texture_mapping=do_texture_mapping or do_render,
    )
    return expected_outputs
@spaces.GPU
def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
    """Stage 4: optionally render ``<save_folder>/mesh.obj`` to a GIF.

    ``obj_dst`` is accepted for the UI wiring but not used. The mesh is
    always read from ``save_folder``. Returns the path of
    ``<save_folder>/output.gif``, or None when ``do_render_gif`` is falsy.
    """
    if do_render_gif:
        gif_path = save_folder + '/output.gif'
        worker_gif(save_folder + '/mesh.obj', gif_dst_path=gif_path)
        return gif_path
    return None
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown(CONST_HEADER)

    with gr.Row(variant="panel"):
        # Left column: the two input modes, one tab each.
        with gr.Column(scale=2):
            with gr.Tab("Text to 3D"):
                with gr.Column():
                    text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
                                       lines=1, max_lines=10, label='Input text')
                    with gr.Row():
                        # "T2I" settings feed stage_0_t2i; "Gen" settings feed stages 2-3.
                        textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
                        textgen_step = gr.Number(value=25, label="T2I step", precision=0)
                        textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
                        textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
                        textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
                    with gr.Row():
                        textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
                        textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
                        textgen_submit = gr.Button("Generate", variant="primary")
                    gr.Examples(examples=example_ts, inputs=[text], label="Txt examples")

            with gr.Tab("Image to 3D"):
                with gr.Column():
                    input_image = gr.Image(label="Input image", width=256, height=256,
                                           type="pil", image_mode="RGBA", sources="upload")
                    with gr.Row():
                        imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
                        imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
                        imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
                    with gr.Row():
                        imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False)
                        imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False)
                        imggen_submit = gr.Button("Generate", variant="primary")
                    gr.Examples(examples=example_is, inputs=[input_image], label="Img examples")

        # Right column: one tab per intermediate or final artifact.
        with gr.Column(scale=3):
            with gr.Tab("rembg image"):
                rem_bg_image = gr.Image(label="No background image", width=256, height=256,
                                        type="pil", image_mode="RGBA")
            with gr.Tab("Multi views"):
                result_image = gr.Image(label="Multi views", type="pil")
            with gr.Tab("Obj"):
                result_3dobj = gr.Model3D(label="Output obj")
            with gr.Tab("Glb"):
                result_3dglb = gr.Model3D(label="Output glb")
            with gr.Tab("GIF"):
                result_gif = gr.Image(label="Rendered GIF")

    # Session state handed from one pipeline stage to the next.
    none = gr.State(None)
    save_folder = gr.State()
    cond_image = gr.State()
    views_image = gr.State()
    text_image = gr.State()

    def _chain_views_to_mesh(event, gen_seed, gen_step, max_faces,
                             do_texture_mapping, do_render_gif):
        """Attach stages 2-4 (views -> mesh -> optional GIF) after *event* succeeds."""
        return event.success(
            fn=stage_2_i2v,
            inputs=[rem_bg_image, gen_seed, gen_step, save_folder],
            outputs=[views_image, cond_image, result_image],
        ).success(
            fn=stage_3_v23,
            inputs=[views_image, cond_image, gen_seed, save_folder, max_faces,
                    do_texture_mapping, do_render_gif],
            outputs=[result_3dobj, result_3dglb],
        ).success(
            fn=stage_4_gif,
            inputs=[result_3dglb, save_folder, do_render_gif],
            outputs=[result_gif],
        )

    # Text tab: stage 0 generates the image and already strips its background.
    _chain_views_to_mesh(
        textgen_submit.click(
            fn=stage_0_t2i,
            inputs=[text, none, textgen_seed, textgen_step],
            outputs=[rem_bg_image, save_folder],
        ),
        textgen_SEED, textgen_STEP, textgen_max_faces,
        textgen_do_texture_mapping, textgen_do_render_gif,
    )

    # Image tab: stage 0 only stores the upload; stage 1 removes the background.
    _chain_views_to_mesh(
        imggen_submit.click(
            fn=stage_0_t2i,
            inputs=[none, input_image, textgen_seed, textgen_step],
            outputs=[text_image, save_folder],
        ).success(
            fn=stage_1_xbg,
            inputs=[text_image, save_folder],
            outputs=[rem_bg_image],
        ),
        imggen_SEED, imggen_STEP, imggen_max_faces,
        imggen_do_texture_mapping, imggen_do_render_gif,
    )
# Bound the pending-request queue, then serve the app. Blocks.queue() returns
# the Blocks instance itself, so the two calls can be chained.
# NOTE(review): Hugging Face Spaces serves Gradio apps on port 7860 by default;
# confirm that CONST_PORT (8080) matches the Space's configured app_port.
demo.queue(max_size=CONST_MAX_QUEUE).launch(server_name=CONST_SERVER, server_port=CONST_PORT)