Spaces:
Sleeping
Sleeping
import torch | |
import os | |
import numpy as np | |
import cv2 | |
import PIL.Image as pil_img | |
import subprocess | |
subprocess.run( | |
'pip install networkx==2.5' | |
.split() | |
) | |
import gradio as gr | |
import trimesh | |
import plotly.graph_objects as go | |
from models.deco import DECO | |
from common import constants | |
if torch.cuda.is_available(): | |
device = torch.device('cuda') | |
else: | |
device = torch.device('cpu') | |
description = ''' | |
### DECO: Dense Estimation of 3D Human-Scene Contact in the Wild (ICCV 2023, Oral) | |
<table> | |
<th width="20%"> | |
<ul> | |
<li><strong><a href="https://deco.is.tue.mpg.de/">Homepage</a></strong> | |
<li><strong><a href="https://github.com/sha2nkt/deco">Code</a></strong> | |
<li><strong><a href="https://openaccess.thecvf.com/content/ICCV2023/html/Tripathi_DECO_Dense_Estimation_of_3D_Human-Scene_Contact_In_The_Wild_ICCV_2023_paper.html">Paper</a></strong> | |
</ul> | |
<br> | |
<ul> | |
<li><strong>Colab Notebook</strong> <a href='https://colab.research.google.com/drive/1fTQdI2AHEKlwYG9yIb2wqicIMhAa067_?usp=sharing'><img style="display: inline-block;" src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a></li> | |
</ul> | |
<br> | |
<iframe src="https://ghbtns.com/github-btn.html?user=sha2nkt&repo=deco&type=star&count=true&v=2&size=small" frameborder="0" scrolling="0" width="100" height="20"></iframe> | |
</th> | |
</table> | |
#### Citation | |
``` | |
@InProceedings{tripathi2023deco, | |
author = {Tripathi, Shashank and Chatterjee, Agniv and Passy, Jean-Claude and Yi, Hongwei and Tzionas, Dimitrios and Black, Michael J.}, | |
title = {{DECO}: Dense Estimation of {3D} Human-Scene Contact In The Wild}, | |
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, | |
month = {October}, | |
year = {2023}, | |
pages = {8001-8013} | |
} | |
``` | |
<details> | |
<summary>More</summary> | |
#### Acknowledgments: | |
- [ECON](https://huggingface.co./spaces/Yuliang/ECON) | |
</details> | |
''' | |
DEFAULT_LIGHTING = dict( | |
ambient=0.6, | |
diffuse=0.5, | |
fresnel=0.01, | |
specular=0.1, | |
roughness=0.001) | |
DEFAULT_LIGHT_POSITION = dict(x=6, y=0, z=10) | |
def initiate_model(model_path): | |
deco_model = DECO('hrnet', True, device) | |
print(f'Loading weights from {model_path}') | |
checkpoint = torch.load(model_path) | |
deco_model.load_state_dict(checkpoint['deco'], strict=True) | |
deco_model.eval() | |
return deco_model | |
def create_layout(dummy, camera=None): | |
if camera is None: | |
camera = dict( | |
up=dict(x=0, y=1, z=0), | |
center=dict(x=0, y=0, z=0), | |
eye=dict(x=dummy.x.mean(), y=0, z=3), | |
projection=dict(type='perspective')) | |
layout = dict( | |
scene={ | |
"xaxis": { | |
'showgrid': False, | |
'zeroline': False, | |
'visible': False, | |
"range": [dummy.x.min(), dummy.x.max()] | |
}, | |
"yaxis": { | |
'showgrid': False, | |
'zeroline': False, | |
'visible': False, | |
"range": [dummy.y.min(), dummy.y.max()] | |
}, | |
"zaxis": { | |
'showgrid': False, | |
'zeroline': False, | |
'visible': False, | |
"range": [dummy.z.min(), dummy.z.max()] | |
}, | |
}, | |
autosize=False, | |
width=750, height=1000, | |
scene_camera=camera, | |
scene_aspectmode="data", | |
clickmode="event+select", | |
margin={'l': 0, 't': 0} | |
) | |
return layout | |
def create_fig(dummy, camera=None): | |
fig = go.Figure( | |
data=dummy.mesh_3d(), | |
layout=create_layout(dummy, camera)) | |
return fig | |
class Dummy: | |
def __init__(self, mesh_path): | |
"""A simple polygonal dummy with colored patches.""" | |
self._load_trimesh(mesh_path) | |
def _load_trimesh(self, path): | |
"""Load a mesh given a path to a .PLY file.""" | |
self._trimesh = trimesh.load(path, process=False) | |
self._vertices = np.array(self._trimesh.vertices) | |
self._faces = np.array(self._trimesh.faces) | |
self.colors = self._trimesh.visual.vertex_colors | |
def vertices(self): | |
"""All the mesh vertices.""" | |
return self._vertices | |
def faces(self): | |
"""All the mesh faces.""" | |
return self._faces | |
def n_vertices(self): | |
"""Number of vertices in a mesh.""" | |
return self._vertices.shape[0] | |
def n_faces(self): | |
"""Number of faces in a mesh.""" | |
return self._faces.shape[0] | |
def x(self): | |
"""An array of vertex x coordinates""" | |
return self._vertices[:, 0] | |
def y(self): | |
"""An array of vertex y coordinates""" | |
return self._vertices[:, 1] | |
def z(self): | |
"""An array of vertex z coordinates""" | |
return self._vertices[:, 2] | |
def i(self): | |
"""An array of the first face vertices""" | |
return self._faces[:, 0] | |
def j(self): | |
"""An array of the second face vertices""" | |
return self._faces[:, 1] | |
def k(self): | |
"""An array of the third face vertices""" | |
return self._faces[:, 2] | |
def default_selection(self): | |
"""Default patch selection mask.""" | |
return dict(vertices=[]) | |
def mesh_3d( | |
self, | |
lighting=DEFAULT_LIGHTING, | |
light_position=DEFAULT_LIGHT_POSITION | |
): | |
"""Construct a Mesh3D object give a clickmask for patch coloring.""" | |
return go.Mesh3d( | |
x=self.x, y=self.y, z=self.z, | |
i=self.i, j=self.j, k=self.k, | |
vertexcolor=self.colors, | |
lighting=lighting, | |
lightposition=light_position, | |
hoverinfo='none') | |
def main(pil_img, out_dir='demo_out', model_path='checkpoint/deco_best.pth', mesh_colour=[130, 130, 130, 255], annot_colour=[0, 255, 0, 255]): | |
deco_model = initiate_model(model_path) | |
smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply') | |
img = np.array(pil_img) | |
img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC) | |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
img = img.transpose(2,0,1)/255.0 | |
img = img[np.newaxis,:,:,:] | |
img = torch.tensor(img, dtype = torch.float32).to(device) | |
with torch.no_grad(): | |
cont, _, _ = deco_model(img) | |
cont = cont.detach().cpu().numpy().squeeze() | |
cont_smpl = [] | |
for indx, i in enumerate(cont): | |
if i >= 0.5: | |
cont_smpl.append(indx) | |
img = img.detach().cpu().numpy() | |
img = np.transpose(img[0], (1, 2, 0)) | |
img = img * 255 | |
img = img.astype(np.uint8) | |
contact_smpl = np.zeros((1, 1, 6890)) | |
contact_smpl[0][0][cont_smpl] = 1 | |
body_model_smpl = trimesh.load(smpl_path, process=False) | |
for vert in range(body_model_smpl.visual.vertex_colors.shape[0]): | |
body_model_smpl.visual.vertex_colors[vert] = mesh_colour | |
body_model_smpl.visual.vertex_colors[cont_smpl] = annot_colour | |
mesh_out_dir = os.path.join(out_dir, 'Preds') | |
os.makedirs(mesh_out_dir, exist_ok=True) | |
print(f'Saving mesh to {mesh_out_dir}') | |
body_model_smpl.export(os.path.join(mesh_out_dir, 'pred.obj')) | |
dummy = Dummy(os.path.join(mesh_out_dir, 'pred.obj')) | |
fig = create_fig(dummy) | |
return fig, os.path.join(mesh_out_dir, 'pred.obj') | |
with gr.Blocks(title="DECO", css=".gradio-container") as demo: | |
gr.Markdown(description) | |
gr.HTML("""<h1 style="text-align:center; color:#10768c">DECO Demo</h1>""") | |
with gr.Row(): | |
with gr.Column(): | |
input_image = gr.Image(label="Input image", type="pil") | |
with gr.Column(): | |
output_image = gr.Plot(label="Renders") | |
output_meshes = gr.File(label="3D meshes") | |
gr.HTML("""<br/>""") | |
with gr.Row(): | |
send_btn = gr.Button("Infer") | |
send_btn.click(fn=main, inputs=[input_image], outputs=[output_image, output_meshes]) | |
example_images = gr.Examples([ | |
['/home/user/app/example_images/213.jpg'], | |
['/home/user/app/example_images/pexels-photo-207569.webp'], | |
['/home/user/app/example_images/pexels-photo-3622517.webp'], | |
['/home/user/app/example_images/pexels-photo-15732209.jpeg'], | |
], | |
inputs=[input_image]) | |
demo.launch(debug=True) |