""" | |
A model worker executes the model. | |
""" | |
import argparse
import asyncio
import base64
import io
import logging
import logging.handlers
import os
import sys
import threading
import traceback
import uuid
from io import BytesIO

import torch
import trimesh
import uvicorn
from PIL import Image
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import JSONResponse, FileResponse

from hy3dgen.rembg import BackgroundRemover
from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FloaterRemover, DegenerateFaceRemover, FaceReducer
from hy3dgen.texgen import Hunyuan3DPaintPipeline
from hy3dgen.text2image import HunyuanDiTPipeline
LOGDIR = '.'

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None
def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
SAVE_DIR = 'gradio_cache'
os.makedirs(SAVE_DIR, exist_ok=True)

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("controller", f"{SAVE_DIR}/controller.log")


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))
def load_image_from_dir(image: UploadFile):
    """Loads a PIL image from an uploaded file."""
    try:
        with image.file as f:  # ensures the file is properly closed after reading
            image_bytes = f.read()
        return Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        # Raise instead of returning an error dict, so callers don't pass a
        # dict on to the background remover / pipeline by mistake.
        raise ValueError(f"Failed to read image: {e}")
class ModelWorker:
    def __init__(self, model_path='tencent/Hunyuan3D-2', device='cuda'):
        self.model_path = model_path
        self.worker_id = worker_id
        self.device = device
        logger.info(f"Loading the model {model_path} on worker {worker_id} ...")

        self.rembg = BackgroundRemover()
        self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
            model_path, cache_dir='content/ditto-api/tencent/Hunyuan3D-2', device=device)
        # self.pipeline_t2i = HunyuanDiTPipeline('Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled',
        #                                        device=device)
        self.pipeline_tex = Hunyuan3DPaintPipeline.from_pretrained(model_path)
    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }
    def generate(self, uid, form):
        # `form` is either a multipart form (from /generate) or a JSON dict
        # (from /send); generation options are read from it below. The original
        # looked these options up in the empty `params` dict, so the mesh and
        # texture branches were unreachable and defaults were always used.
        params = dict()

        image = form.get("image")
        if image:
            if isinstance(image, str):
                # JSON bodies carry the image base64-encoded.
                image = load_image_from_base64(image)
            else:
                # Multipart forms carry it as an UploadFile.
                image = load_image_from_dir(image)
            image = self.rembg(image)
            params['image'] = image

        if 'mesh' in form:
            mesh = trimesh.load(BytesIO(base64.b64decode(form["mesh"])), file_type='glb')
        else:
            # Form fields arrive as strings, so cast the numeric ones explicitly.
            seed = int(form.get("seed", 1234))
            params['generator'] = torch.Generator(self.device).manual_seed(seed)
            params['octree_resolution'] = int(form.get("octree_resolution", 256))
            params['num_inference_steps'] = int(form.get("num_inference_steps", 30))
            params['guidance_scale'] = float(form.get('guidance_scale', 7.5))
            params['mc_algo'] = 'mc'
            mesh = self.pipeline(**params)[0]

        if form.get('texture', False):
            # Clean the mesh before texturing: drop floaters and degenerate
            # faces, then reduce the face count for the paint pipeline.
            mesh = FloaterRemover()(mesh)
            mesh = DegenerateFaceRemover()(mesh)
            mesh = FaceReducer()(mesh, max_facenum=int(form.get('face_count', 40000)))
            mesh = self.pipeline_tex(mesh, image)

        save_path = os.path.join(SAVE_DIR, f'{uid}.glb')
        logger.info("Saving mesh to %s", save_path)
        mesh.export(save_path)
        torch.cuda.empty_cache()
        return save_path, uid
app = FastAPI()

# Set in __main__ before the server starts; guards model concurrency.
model_semaphore = None
@app.post("/generate")
async def generate(request: Request):
    """Synchronous generation: accepts a multipart form and streams back the GLB."""
    logger.info("Worker generating...")
    form = await request.form()
    uid = uuid.uuid4()
    try:
        file_path, uid = worker.generate(uid, form)
        return FileResponse(file_path)
    except ValueError as e:
        traceback.print_exc()
        logger.error("Caught ValueError: %s", e)
        ret = {
            "text": server_error_msg,
            "error_code": 1,
        }
        return JSONResponse(ret, status_code=404)
    except torch.cuda.CudaError as e:
        logger.error("Caught torch.cuda.CudaError: %s", e)
        ret = {
            "text": server_error_msg,
            "error_code": 1,
        }
        return JSONResponse(ret, status_code=404)
    except Exception as e:
        traceback.print_exc()
        logger.error("Caught unknown error: %s", e)
        ret = {
            "text": server_error_msg,
            "error_code": 1,
        }
        return JSONResponse(ret, status_code=404)
@app.post("/send")
async def send(request: Request):
    """Asynchronous generation: kicks off a job and returns its uid immediately.

    Renamed from a second `generate` (which would have shadowed the handler
    above); the "Worker send..." log message names the intended endpoint.
    """
    logger.info("Worker send...")
    params = await request.json()
    uid = uuid.uuid4()
    # Run generation in the background; the client polls /status/{uid} for the result.
    threading.Thread(target=worker.generate, args=(uid, params,)).start()
    ret = {"uid": str(uid)}
    return JSONResponse(ret, status_code=200)
@app.get("/status/{uid}")
async def status(uid: str):
    save_file_path = os.path.join(SAVE_DIR, f'{uid}.glb')
    if not os.path.exists(save_file_path):
        return JSONResponse({'status': 'processing'}, status_code=200)
    with open(save_file_path, 'rb') as f:
        base64_str = base64.b64encode(f.read()).decode()
    response = {'status': 'completed', 'model_base64': base64_str}
    return JSONResponse(response, status_code=200)
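# A minimal client sketch for the async flow above (an illustration, not part
# of the worker; the host, port, payload fields, and output filename are
# assumptions):
#
#   import base64, time, requests
#
#   uid = requests.post("http://localhost:8081/send", json={"seed": 1234}).json()["uid"]
#   while True:
#       ret = requests.get(f"http://localhost:8081/status/{uid}").json()
#       if ret["status"] == "completed":
#           with open("out.glb", "wb") as f:
#               f.write(base64.b64decode(ret["model_base64"]))
#           break
#       time.sleep(2)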
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--host", type=str, default="0.0.0.0") | |
parser.add_argument("--port", type=str, default=8081) | |
parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2') | |
parser.add_argument("--device", type=str, default="cuda") | |
parser.add_argument("--limit-model-concurrency", type=int, default=5) | |
args = parser.parse_args() | |
logger.info(f"args: {args}") | |
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) | |
worker = ModelWorker(model_path=args.model_path, device=args.device) | |
uvicorn.run(app, host=args.host, port=args.port, log_level="info") | |
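# Example launch and synchronous request (a sketch; the script name and the
# input/output filenames are assumptions):
#
#   python model_worker.py --model_path tencent/Hunyuan3D-2 --device cuda
#
#   curl -F "image=@input.png" http://localhost:8081/generate -o output.glb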