GenMM / app.py
wyysf's picture
add header with title (#1)
88aa0dc
raw
history blame
2.8 kB
import json
import time
import uvicorn
from pathlib import Path
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from dataset.tracks_motion import TracksMotion
from GPS import GPS
import gradio as gr
def _synthesis(synthesis_setting, motion_data):
    """Run GPS "match_and_blend" synthesis on the CPU and return its result.

    Args:
        synthesis_setting: Mapping of user-selected options. Keys read here:
            ``frames``, ``noise_sigma``, ``pyr_factor``, ``patch_size``,
            ``loop``, ``completeness``, ``num_steps``, ``alpha`` (looked up
            only when ``completeness`` is truthy) and the optional
            ``stride`` (defaults to 1 when absent).
        motion_data: Passed unchanged as the first argument of ``GPS.run``;
            ``synthesis()`` in this file passes a one-element list holding a
            ``TracksMotion``.

    Returns:
        Whatever ``GPS.run`` returns for the given motion data.
    """
    model = GPS(
        init_mode=f"random_synthesis/{synthesis_setting['frames']}",
        noise_sigma=synthesis_setting["noise_sigma"],
        # NOTE(review): hard-coded; synthesis() writes
        # setting["coarse_ratio"] = -1, but that value is never read here.
        coarse_ratio=0.2,
        pyr_factor=synthesis_setting["pyr_factor"],
        num_stages_limit=-1,
        silent=True,
        device="cpu",
    )
    criteria = {
        "type": "PatchCoherentLoss",
        "patch_size": synthesis_setting["patch_size"],
        "stride": synthesis_setting.get("stride", 1),
        "loop": synthesis_setting["loop"],
        # "alpha" is only looked up when the completeness term is enabled,
        # so the key may be absent otherwise.
        "coherent_alpha": (
            synthesis_setting["alpha"] if synthesis_setting["completeness"] else None
        ),
    }
    return model.run(
        motion_data,
        mode="match_and_blend",
        ext={
            "criteria": criteria,
            "optimizer": "match_and_blend",
            "num_itrs": synthesis_setting["num_steps"],
        },
    )
def synthesis(data):
    """Handle one synthesis request coming from the web demo.

    Args:
        data: The request payload, either as JSON text (str/bytes) or as an
            already-decoded dict. It must contain ``tracks``, ``scale`` and a
            ``setting`` mapping accepted by ``_synthesis``.

    Returns:
        The decoded payload (mutated in place when a dict was passed) with
        ``setting["coarse_ratio"]`` set to -1, ``time`` set to the elapsed
        synthesis time in seconds, and ``tracks`` replaced by
        ``TracksMotion.parse`` of the synthesized motion.
    """
    # Accept raw JSON text as before, but also an already-decoded object
    # (json.loads would raise TypeError on a dict).
    if isinstance(data, (str, bytes, bytearray)):
        data = json.loads(data)
    # NOTE(review): echoed back to the client only; _synthesis() does not
    # read this key and hard-codes coarse_ratio=0.2.
    data["setting"]["coarse_ratio"] = -1
    # create track object
    motion_data = TracksMotion(data["tracks"], scale=data["scale"])
    # perf_counter is monotonic and intended for measuring durations.
    start = time.perf_counter()
    synthesized_motion = _synthesis(data["setting"], [motion_data])
    data["time"] = time.perf_counter() - start
    data["tracks"] = motion_data.parse(synthesized_motion)
    return data
# HTML header rendered above the embedded demo: the paper title plus links to
# the project page, the paper and the code repository.
intro = """
<h1 style="text-align: center;">
Example-based Motion Synthesis via Generative Motion Matching
</h1>
<h3 style="text-align: center; margin-bottom: 7px;">
<a href="http://weiyuli.xyz/GenMM" target="_blank">Project Page</a> | <a href="https://huggingface.co/papers/2306.00378" target="_blank">Paper</a> | <a href="https://github.com/wyysf-98/GenMM" target="_blank">Code</a>
</h3>
"""
with gr.Blocks() as demo:
    gr.HTML(intro)
    # Embed the static front-end that is served under /GenMM_demo/.
    gr.HTML(
        """<iframe src="/GenMM_demo/" width="100%" height="700px" style="border:none;"></iframe>"""
    )
    # Invisible components: they exist only to expose `synthesis` as the
    # "predict" API endpoint (presumably called by the embedded page).
    json_in = gr.JSON(visible=False)
    json_out = gr.JSON(visible=False)
    btn = gr.Button("Synthesize", visible=False)
    btn.click(synthesis, inputs=[json_in], outputs=[json_out], api_name="predict")
app = FastAPI()
# Resolve next to this file rather than the current working directory:
# StaticFiles raises at construction time if the directory does not exist.
static_dir = Path(__file__).resolve().parent / "GenMM_demo"
# Mount the static demo before Gradio takes over "/", because routes are
# matched in registration order and the "/" mount would otherwise shadow it.
app.mount("/GenMM_demo", StaticFiles(directory=static_dir, html=True), name="static")
app = gr.mount_gradio_app(app, demo, path="/")
# serve the app
if __name__ == "__main__":
    # Pass the already-built ASGI app instead of the "app:app" import string:
    # when this file runs as __main__, the string makes uvicorn import it a
    # second time as module `app`, rebuilding the whole Gradio/FastAPI app.
    uvicorn.run(app, host="0.0.0.0", port=7860)