import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import socket
import webbrowser

sys.path.append('vggt/')

import shutil
from datetime import datetime
from demo_hf import demo_fn  # , initialize_model
from omegaconf import DictConfig, OmegaConf
import glob
import gc
import time

from viser_fn import viser_wrapper
from gradio_util import demo_predictions_to_glb
from hydra.utils import instantiate
import spaces


# def get_free_port():
#     """Get a free port using socket."""
#     # return 80
#     # return 8080
#     # return 10088  # for debugging
#     # return 7860
#     # return 7888
#     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
#         s.bind(('', 0))
#         port = s.getsockname()[1]
#     return port


# Instantiate the model from the Hydra/OmegaConf config.
cfg_file = "config/base.yaml"
cfg = OmegaConf.load(cfg_file)
vggt_model = instantiate(cfg, _recursive_=False)

_VGGT_URL = "https://huggingface.co./facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

# Load the pretrained vggt_model weights.
pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

if "vggt_model" in pretrain_model:
    model_dict = pretrain_model["vggt_model"]
    vggt_model.load_state_dict(model_dict, strict=False)
else:
    vggt_model.load_state_dict(pretrain_model, strict=True)


# @torch.inference_mode()
@spaces.GPU(duration=240)
def vggt_demo(
    input_video,
    input_image,
    conf_thres=3.0,
    frame_filter="all",
    mask_black_bg=False,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique working directory for this run.
    # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        # Copy the uploaded images into the working directory.
        input_image = sorted(input_image)
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        # Sample roughly one frame per second from the uploaded video.
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)

        frame_rate = 1
        frame_interval = int(fps * frame_rate)

        video_frame_num = 0
        count = 0

        while True:
            gotit, frame = vs.read()
            count += 1

            if not gotit:
                break

            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1
    else:
        return None, "Uploading not finished or incorrect input format", None, None

    all_files = sorted(os.listdir(target_dir_images))
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]

    # Update frame_filter choices
    frame_filter_choices = ["All"] + all_files

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    print("Running demo_fn")
    with torch.no_grad():
        predictions = demo_fn(cfg, vggt_model)

    predictions["pred_extrinsic_list"] = None
    print("Saving predictions")

    prediction_save_path = f"{target_dir}/predictions.npz"
    np.savez(prediction_save_path, **predictions)

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
    glbscene = demo_predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
    )
    glbscene.export(file_obj=glbfile)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

    # Return None for the 3D vggt_model (since we're using viser) and the viser URL
    # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
    # print(viser_url)  # Debug print

    log = "Success. Waiting for visualization."

    return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)


def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    # predictions = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    # predictions["arr_0"]
    # for key in predictions.files: print(key)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

    if not os.path.exists(glbfile):
        glbscene = demo_predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization", target_dir


statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"

###########################################################################################
apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')
###########################################################################################

with gr.Blocks() as demo:
    gr.Markdown("""
# 🏛️ VGGT: Visual Geometry Grounded Transformer

Alpha version (under active development)

Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.

Usage Tips:

  1. After reconstruction, you can fine-tune the visualization by adjusting the confidence threshold or selecting specific frames to display, then click "Update Visualization".
  2. Performance note: While the model itself processes quickly (~0.2 seconds), initial setup and visualization may take longer. First-time use requires downloading model weights, and rendering dense point clouds can be resource-intensive.
  3. Known limitation: The model currently exhibits inconsistent behavior with videos centered around human subjects. This issue is being addressed in upcoming updates.
""") with gr.Row(): with gr.Column(scale=1): input_video = gr.Video(label="Upload Video", interactive=True) input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) with gr.Column(scale=3): with gr.Column(): gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)**") reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5) # reconstruction_output = gr.Model3D(label="3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)", height=520, zoom_speed=0.5, pan_speed=0.5) # Move these controls to a new row above the log output with gr.Row(): conf_thres = gr.Slider(minimum=0.1, maximum=20.0, value=3.0, step=0.1, label="Conf Thres") frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame") mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False) log_output = gr.Textbox(label="Log") # Add a hidden textbox for target_dir target_dir_output = gr.Textbox(label="Target Dir", visible=False) with gr.Row(): submit_btn = gr.Button("Reconstruct", scale=1) revisual_btn = gr.Button("Update Visualization", scale=1) clear_btn = gr.ClearButton([input_video, input_images, reconstruction_output, log_output, target_dir_output], scale=1) #Modified reconstruction_output examples = [ [counter_video, counter_images, 1.5, "All", False], [flower_video, flower_images, 1.5, "All", False], [kitchen_video, kitchen_images, 3, "All", False], [fern_video, fern_images, 1.5, "All", False], # [person_video, person_images], # [statue_video, statue_images], # [drums_video, drums_images], # [horns_video, horns_images, 1.5, "All", False], # [apple_video, apple_images], # [bonsai_video, bonsai_images], ] gr.Examples(examples=examples, inputs=[input_video, input_images, conf_thres, frame_filter, mask_black_bg], outputs=[reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter fn=vggt_demo, # Use our wrapper function cache_examples=False, examples_per_page=50, ) submit_btn.click( vggt_demo, # Use the same wrapper function [input_video, input_images, conf_thres, frame_filter, mask_black_bg], [reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter to outputs # concurrency_limit=1 ) revisual_btn.click( update_visualization, [target_dir_output, conf_thres, frame_filter, mask_black_bg], [reconstruction_output, log_output, target_dir_output], ) # demo.launch(debug=True, share=True) # demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False) # demo.queue(max_size=20).launch(show_error=True, share=True) demo.queue(max_size=20).launch(show_error=True) #, share=True, server_port=7888, server_name="0.0.0.0") # share=True # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True) ########################################################################################################################