# vggt / app.py
import gc
import glob
import os
import shutil
import socket
import sys
import time
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

sys.path.append('vggt/')

from demo_hf import demo_fn  # , initialize_model
from gradio_util import demo_predictions_to_glb
from viser_fn import viser_wrapper
# def get_free_port():
#     """Get a free port using socket."""
#     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
#         s.bind(('', 0))
#         port = s.getsockname()[1]
#         return port
# Build the VGGT model from the Hydra config and load the pretrained weights.
cfg_file = "config/base.yaml"
cfg = OmegaConf.load(cfg_file)
vggt_model = instantiate(cfg, _recursive_=False)

_VGGT_URL = "https://huggingface.co./facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

# The checkpoint may either wrap the weights under a "vggt_model" key or be a raw state dict.
pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

if "vggt_model" in pretrain_model:
    model_dict = pretrain_model["vggt_model"]
    vggt_model.load_state_dict(model_dict, strict=False)
else:
    vggt_model.load_state_dict(pretrain_model, strict=True)
# @torch.inference_mode()
@spaces.GPU(duration=240)
def vggt_demo(
    input_video,
    input_image,
    conf_thres=3.0,
    frame_filter="all",
    mask_black_bg=False,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique working directory for this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)

    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    # Load a fresh config for this run; SCENE_DIR is set per request below.
    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        # Copy the uploaded images into the working directory.
        input_image = sorted(input_image)
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        # Sample roughly one frame per second from the uploaded video.
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)

        frame_rate = 1
        frame_interval = int(fps * frame_rate)

        video_frame_num = 0
        count = 0
        while True:
            gotit, frame = vs.read()
            count += 1
            if not gotit:
                break
            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1
    else:
        return None, "Upload not finished or incorrect input format", None, None

    all_files = sorted(os.listdir(target_dir_images))
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    # Choices for the frame_filter dropdown: "All" plus one entry per extracted frame.
    frame_filter_choices = ["All"] + all_files

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    print("Running demo_fn")
    with torch.no_grad():
        predictions = demo_fn(cfg, vggt_model)

    predictions["pred_extrinsic_list"] = None

    print("Saving predictions")
    prediction_save_path = f"{target_dir}/predictions.npz"
    np.savez(prediction_save_path, **predictions)

    # Export the point cloud and camera poses to a GLB file for the Model3D viewer.
    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
    glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
    glbscene.export(file_obj=glbfile)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

    log = "Success. Waiting for visualization."
    return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
    # Rebuild the GLB scene from the saved predictions without rerunning the model.
    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

    if not os.path.exists(glbfile):
        glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization", target_dir
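
# Example videos bundled with the demo.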
statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"
###########################################################################################
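# Example image folders matching the videos above.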
apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')
###########################################################################################
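# Build the Gradio interface.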
with gr.Blocks() as demo:
    gr.Markdown("""
# 🏛️ VGGT: Visual Geometry Grounded Transformer

<div style="font-size: 16px; line-height: 1.5;">
<p><strong>Alpha version</strong> (under active development)</p>

<p>Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.</p>

<h3>Usage Tips:</h3>
<ol>
    <li>After reconstruction, you can fine-tune the visualization by adjusting the confidence threshold or selecting specific frames to display, then click "Update Visualization".</li>
    <li>Performance note: while the model itself runs quickly (~0.2 seconds), initial setup and visualization may take longer. First-time use requires downloading model weights, and rendering dense point clouds can be resource-intensive.</li>
    <li>Known limitation: the model currently behaves inconsistently on videos centered on human subjects. This issue is being addressed in upcoming updates.</li>
</ol>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

        with gr.Column(scale=3):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)**")
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

    # Visualization controls, shown in a row above the log output.
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.1, maximum=20.0, value=3.0, step=0.1, label="Conf Thres")
        frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
        mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)

    log_output = gr.Textbox(label="Log")

    # Hidden textbox that carries the per-run working directory between callbacks.
    target_dir_output = gr.Textbox(label="Target Dir", visible=False)

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        revisual_btn = gr.Button("Update Visualization", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, reconstruction_output, log_output, target_dir_output], scale=1)

    examples = [
        [counter_video, counter_images, 1.5, "All", False],
        [flower_video, flower_images, 1.5, "All", False],
        [kitchen_video, kitchen_images, 3, "All", False],
        [fern_video, fern_images, 1.5, "All", False],
        # [person_video, person_images],
        # [statue_video, statue_images],
        # [drums_video, drums_images],
        # [horns_video, horns_images, 1.5, "All", False],
        # [apple_video, apple_images],
        # [bonsai_video, bonsai_images],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_video, input_images, conf_thres, frame_filter, mask_black_bg],
        outputs=[reconstruction_output, log_output, target_dir_output, frame_filter],
        fn=vggt_demo,
        cache_examples=False,
        examples_per_page=50,
    )

    submit_btn.click(
        vggt_demo,
        [input_video, input_images, conf_thres, frame_filter, mask_black_bg],
        [reconstruction_output, log_output, target_dir_output, frame_filter],
    )

    revisual_btn.click(
        update_visualization,
        [target_dir_output, conf_thres, frame_filter, mask_black_bg],
        [reconstruction_output, log_output, target_dir_output],
    )
demo.queue(max_size=20).launch(show_error=True)
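
# The snippet below is a minimal offline sketch (kept commented out so it does not affect
# the running app) of how a predictions.npz saved by a previous run could be reloaded and
# re-exported to GLB with a different confidence threshold, using the same
# demo_predictions_to_glb helper the app calls above. The "input_images_..." directory
# name is a placeholder for an actual run directory, not a real path.
#
# loaded = np.load("input_images_YYYYMMDD_HHMMSS_ffffff/predictions.npz", allow_pickle=True)
# predictions = {key: loaded[key] for key in loaded.keys()}
# scene = demo_predictions_to_glb(predictions, conf_thres=5.0, filter_by_frames="All", mask_black_bg=False)
# scene.export(file_obj="rethresholded_scene.glb")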