import gc
import glob
import os
import shutil
import sys
import time
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

sys.path.append('vggt/')

from demo_hf import demo_fn
from gradio_util import demo_predictions_to_glb
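

# Instantiate the VGGT model from its Hydra config and load the pretrained
# alpha checkpoint from the Hugging Face Hub.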
cfg_file = "config/base.yaml" |
|
cfg = OmegaConf.load(cfg_file) |
|
vggt_model = instantiate(cfg, _recursive_=False) |
|
_VGGT_URL = "https://huggingface.co./facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt" |
|
|
|
pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL) |
|
|
|
if "vggt_model" in pretrain_model: |
|
model_dict = pretrain_model["vggt_model"] |
|
vggt_model.load_state_dict(model_dict, strict=False) |
|
else: |
|
vggt_model.load_state_dict(pretrain_model, strict=True) |
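

# Main reconstruction entry point. On Hugging Face Spaces, the decorator
# requests a (Zero)GPU worker for up to 240 seconds per call.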
@spaces.GPU(duration=240)
def vggt_demo(
    input_video,
    input_image,
    conf_thres=3.0,
    frame_filter="all",
    mask_black_bg=False,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Stage the inputs in a fresh, timestamped directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    # Gradio may hand the video over as a dict rather than a file path.
    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        # Copy the uploaded images into the staging directory.
        input_image = sorted(input_image)
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        # Sample roughly one frame per second from the uploaded video.
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)

        frame_rate = 1
        frame_interval = max(int(fps * frame_rate), 1)  # guard against fps == 0 on malformed videos

        video_frame_num = 0
        count = 0

        while True:
            gotit, frame = vs.read()
            count += 1

            if not gotit:
                break

            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1
        vs.release()
    else:
        return None, "Upload not finished, or the input format is incorrect", None, None

    all_files = sorted(os.listdir(target_dir_images))
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    print("Running demo_fn")
    with torch.no_grad():
        predictions = demo_fn(cfg, vggt_model)
    predictions["pred_extrinsic_list"] = None
    print("Saving predictions")

    prediction_save_path = f"{target_dir}/predictions.npz"
    np.savez(prediction_save_path, **predictions)

    # Export the point cloud and camera poses as a GLB scene.
    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
    glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
    glbscene.export(file_obj=glbfile)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

    log = "Success. Waiting for visualization."
    return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
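

# Re-render the GLB scene from cached predictions without re-running the model.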
def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

    # Only re-export if this threshold/filter combination has not been rendered yet.
    if not os.path.exists(glbfile):
        glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization", target_dir
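

# Paths to the example videos and image folders bundled with the demo.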
statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"

apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')
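

# Build the Gradio interface.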
with gr.Blocks() as demo:
    gr.Markdown("""
# 🏛️ VGGT: Visual Geometry Grounded Transformer

<div style="font-size: 16px; line-height: 1.5;">
<p><strong>Alpha version</strong> (under active development)</p>

<p>Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.</p>

<h3>Usage Tips:</h3>
<ol>
<li>After reconstruction, you can fine-tune the visualization by adjusting the confidence threshold or selecting specific frames to display, then clicking "Update Visualization".</li>
<li>Performance note: while the model itself processes quickly (~0.2 seconds), initial setup and visualization may take longer. First-time use requires downloading model weights, and rendering dense point clouds can be resource-intensive.</li>
<li>Known limitation: the model currently exhibits inconsistent behavior on videos centered around human subjects. This issue is being addressed in upcoming updates.</li>
</ol>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

        with gr.Column(scale=3):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)**")
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

    with gr.Row():
        conf_thres = gr.Slider(minimum=0.1, maximum=20.0, value=3.0, step=0.1, label="Conf Thres")
        frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
        mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)

    log_output = gr.Textbox(label="Log")

    target_dir_output = gr.Textbox(label="Target Dir", visible=False)

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        revisual_btn = gr.Button("Update Visualization", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, reconstruction_output, log_output, target_dir_output], scale=1)
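
    # Preset examples: [video, images, conf_thres, frame_filter, mask_black_bg].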
    examples = [
        [counter_video, counter_images, 1.5, "All", False],
        [flower_video, flower_images, 1.5, "All", False],
        [kitchen_video, kitchen_images, 3, "All", False],
        [fern_video, fern_images, 1.5, "All", False],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_video, input_images, conf_thres, frame_filter, mask_black_bg],
        outputs=[reconstruction_output, log_output, target_dir_output, frame_filter],
        fn=vggt_demo,
        cache_examples=False,
        examples_per_page=50,
    )
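
    # Wire the buttons to the reconstruction and re-visualization callbacks.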
    submit_btn.click(
        vggt_demo,
        [input_video, input_images, conf_thres, frame_filter, mask_black_bg],
        [reconstruction_output, log_output, target_dir_output, frame_filter],
    )

    revisual_btn.click(
        update_visualization,
        [target_dir_output, conf_thres, frame_filter, mask_black_bg],
        [reconstruction_output, log_output, target_dir_output],
    )
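

# Queue up to 20 pending requests and launch the app, surfacing errors in the UI.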
demo.queue(max_size=20).launch(show_error=True)