Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import os | |
import warnings | |
from huggingface_hub import snapshot_download | |
import gradio as gr | |
from glob import glob | |
import shutil | |
import torch | |
import numpy as np | |
from PIL import Image | |
from einops import rearrange | |
import argparse | |
# Suppress warnings | |
warnings.simplefilter('ignore', category=UserWarning) | |
warnings.simplefilter('ignore', category=FutureWarning) | |
warnings.simplefilter('ignore', category=DeprecationWarning) | |
def download_models(): | |
# Create weights directory if it doesn't exist | |
os.makedirs("weights", exist_ok=True) | |
os.makedirs("weights/hunyuanDiT", exist_ok=True) | |
# Download Hunyuan3D-1 model | |
try: | |
snapshot_download( | |
repo_id="tencent/Hunyuan3D-1", | |
local_dir="./weights", | |
resume_download=True | |
) | |
print("Successfully downloaded Hunyuan3D-1 model") | |
except Exception as e: | |
print(f"Error downloading Hunyuan3D-1: {e}") | |
# Download HunyuanDiT model | |
try: | |
snapshot_download( | |
repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled", | |
local_dir="./weights/hunyuanDiT", | |
resume_download=True | |
) | |
print("Successfully downloaded HunyuanDiT model") | |
except Exception as e: | |
print(f"Error downloading HunyuanDiT: {e}") | |
# Download models before starting the app | |
download_models() | |
# Parse arguments | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--use_lite", default=False, action="store_true") | |
parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str) | |
parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str) | |
parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str) | |
parser.add_argument("--save_memory", default=False, action="store_true") | |
parser.add_argument("--device", default="cuda:0", type=str) | |
args = parser.parse_args() | |
# Constants | |
CONST_PORT = 8080 | |
CONST_MAX_QUEUE = 1 | |
CONST_SERVER = '0.0.0.0' | |
CONST_HEADER = ''' | |
<h2><b>Official 🤗 Gradio Demo</b></h2> | |
<h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'> | |
<b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2> | |
''' | |
# Helper functions | |
def get_example_img_list(): | |
print('Loading example img list ...') | |
return sorted(glob('./demos/example_*.png')) | |
def get_example_txt_list(): | |
print('Loading example txt list ...') | |
txt_list = [] | |
for line in open('./demos/example_list.txt'): | |
txt_list.append(line.strip()) | |
return txt_list | |
example_is = get_example_img_list() | |
example_ts = get_example_txt_list() | |
# Import required workers | |
from infer import seed_everything, save_gif | |
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer | |
# Initialize workers | |
worker_xbg = Removebg() | |
print(f"loading {args.text2image_path}") | |
worker_t2i = Text2Image( | |
pretrain=args.text2image_path, | |
device=args.device, | |
save_memory=args.save_memory | |
) | |
worker_i2v = Image2Views( | |
use_lite=args.use_lite, | |
device=args.device | |
) | |
worker_v23 = Views2Mesh( | |
args.mv23d_cfg_path, | |
args.mv23d_ckt_path, | |
use_lite=args.use_lite, | |
device=args.device | |
) | |
worker_gif = GifRenderer(args.device) | |
# Pipeline stages | |
def stage_0_t2i(text, image, seed, step): | |
os.makedirs('./outputs/app_output', exist_ok=True) | |
exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith(".")) | |
cur_id = min(set(range(30)) - exists) if len(exists) < 30 else 0 | |
if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"): | |
shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}") | |
save_folder = f'./outputs/app_output/{cur_id}' | |
os.makedirs(save_folder, exist_ok=True) | |
dst = save_folder + '/img.png' | |
if not text: | |
if image is None: | |
return dst, save_folder | |
image.save(dst) | |
return dst, save_folder | |
image = worker_t2i(text, seed, step) | |
image.save(dst) | |
dst = worker_xbg(image, save_folder) | |
return dst, save_folder | |
def stage_1_xbg(image, save_folder): | |
if isinstance(image, str): | |
image = Image.open(image) | |
dst = save_folder + '/img_nobg.png' | |
rgba = worker_xbg(image) | |
rgba.save(dst) | |
return dst | |
def stage_2_i2v(image, seed, step, save_folder): | |
if isinstance(image, str): | |
image = Image.open(image) | |
gif_dst = save_folder + '/views.gif' | |
res_img, pils = worker_i2v(image, seed, step) | |
save_gif(pils, gif_dst) | |
views_img, cond_img = res_img[0], res_img[1] | |
img_array = np.asarray(views_img, dtype=np.uint8) | |
show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2) | |
show_img = show_img[worker_i2v.order, ...] | |
show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3) | |
show_img = Image.fromarray(show_img) | |
return views_img, cond_img, show_img | |
def stage_3_v23(views_pil, cond_pil, seed, save_folder, target_face_count=30000, | |
do_texture_mapping=True, do_render=True): | |
do_texture_mapping = do_texture_mapping or do_render | |
obj_dst = save_folder + '/mesh_with_colors.obj' | |
glb_dst = save_folder + '/mesh.glb' | |
worker_v23( | |
views_pil, | |
cond_pil, | |
seed=seed, | |
save_folder=save_folder, | |
target_face_count=target_face_count, | |
do_texture_mapping=do_texture_mapping | |
) | |
return obj_dst, glb_dst | |
def stage_4_gif(obj_dst, save_folder, do_render_gif=True): | |
if not do_render_gif: | |
return None | |
gif_dst = save_folder + '/output.gif' | |
worker_gif( | |
save_folder + '/mesh.obj', | |
gif_dst_path=gif_dst | |
) | |
return gif_dst | |
# Gradio Interface | |
with gr.Blocks() as demo: | |
gr.Markdown(CONST_HEADER) | |
with gr.Row(variant="panel"): | |
with gr.Column(scale=2): | |
with gr.Tab("Text to 3D"): | |
with gr.Column(): | |
text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', | |
lines=1, max_lines=10, label='Input text') | |
with gr.Row(): | |
textgen_seed = gr.Number(value=0, label="T2I seed", precision=0) | |
textgen_step = gr.Number(value=25, label="T2I step", precision=0) | |
textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0) | |
textgen_STEP = gr.Number(value=50, label="Gen step", precision=0) | |
textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0) | |
with gr.Row(): | |
textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False) | |
textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False) | |
textgen_submit = gr.Button("Generate", variant="primary") | |
gr.Examples(examples=example_ts, inputs=[text], label="Txt examples") | |
with gr.Tab("Image to 3D"): | |
with gr.Column(): | |
input_image = gr.Image(label="Input image", width=256, height=256, | |
type="pil", image_mode="RGBA", sources="upload") | |
with gr.Row(): | |
imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0) | |
imggen_STEP = gr.Number(value=50, label="Gen step", precision=0) | |
imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0) | |
with gr.Row(): | |
imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False) | |
imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False) | |
imggen_submit = gr.Button("Generate", variant="primary") | |
gr.Examples(examples=example_is, inputs=[input_image], label="Img examples") | |
with gr.Column(scale=3): | |
with gr.Tab("rembg image"): | |
rem_bg_image = gr.Image(label="No background image", width=256, height=256, | |
type="pil", image_mode="RGBA") | |
with gr.Tab("Multi views"): | |
result_image = gr.Image(label="Multi views", type="pil") | |
with gr.Tab("Obj"): | |
result_3dobj = gr.Model3D(label="Output obj") | |
with gr.Tab("Glb"): | |
result_3dglb = gr.Model3D(label="Output glb") | |
with gr.Tab("GIF"): | |
result_gif = gr.Image(label="Rendered GIF") | |
# States | |
none = gr.State(None) | |
save_folder = gr.State() | |
cond_image = gr.State() | |
views_image = gr.State() | |
text_image = gr.State() | |
# Event handlers | |
textgen_submit.click( | |
fn=stage_0_t2i, | |
inputs=[text, none, textgen_seed, textgen_step], | |
outputs=[rem_bg_image, save_folder], | |
).success( | |
fn=stage_2_i2v, | |
inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder], | |
outputs=[views_image, cond_image, result_image], | |
).success( | |
fn=stage_3_v23, | |
inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, | |
textgen_do_texture_mapping, textgen_do_render_gif], | |
outputs=[result_3dobj, result_3dglb], | |
).success( | |
fn=stage_4_gif, | |
inputs=[result_3dglb, save_folder, textgen_do_render_gif], | |
outputs=[result_gif], | |
) | |
imggen_submit.click( | |
fn=stage_0_t2i, | |
inputs=[none, input_image, textgen_seed, textgen_step], | |
outputs=[text_image, save_folder], | |
).success( | |
fn=stage_1_xbg, | |
inputs=[text_image, save_folder], | |
outputs=[rem_bg_image], | |
).success( | |
fn=stage_2_i2v, | |
inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder], | |
outputs=[views_image, cond_image, result_image], | |
).success( | |
fn=stage_3_v23, | |
inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, | |
imggen_do_texture_mapping, imggen_do_render_gif], | |
outputs=[result_3dobj, result_3dglb], | |
).success( | |
fn=stage_4_gif, | |
inputs=[result_3dglb, save_folder, imggen_do_render_gif], | |
outputs=[result_gif], | |
) | |
demo.queue(max_size=CONST_MAX_QUEUE) | |
demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT) |