import os
import sys
import gc
import glob
import shutil
import time
from datetime import datetime

import cv2
import numpy as np
import torch
import gradio as gr
import spaces
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

sys.path.append('vggt/')

from demo_hf import demo_fn  # , initialize_model
from gradio_util import demo_predictions_to_glb
from viser_fn import viser_wrapper

# Build the VGGT model from the Hydra config and load the pretrained weights.
cfg_file = "config/base.yaml"
cfg = OmegaConf.load(cfg_file)

vggt_model = instantiate(cfg, _recursive_=False)

_VGGT_URL = "https://huggingface.co./facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

# Some checkpoints wrap the weights under a "vggt_model" key; otherwise load
# the state dict directly.
pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

if "vggt_model" in pretrain_model:
    model_dict = pretrain_model["vggt_model"]
    vggt_model.load_state_dict(model_dict, strict=False)
else:
    vggt_model.load_state_dict(pretrain_model, strict=True)


# @torch.inference_mode()
@spaces.GPU(duration=240)
def vggt_demo(
    input_video,
    input_image,
    conf_thres=3.0,
    frame_filter="all",
    mask_black_bg=False,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Use a microsecond-resolution timestamp so concurrent requests get
    # distinct working directories.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    # Reload a fresh config per request so SCENE_DIR changes do not leak
    # across calls.
    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        input_image = sorted(input_image)
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)

        # Sample one frame per second; guard against missing fps metadata.
        frame_rate = 1
        frame_interval = max(int(fps * frame_rate), 1)

        video_frame_num = 0
        count = 0
        while True:
            gotit, frame = vs.read()
            count += 1
            if not gotit:
                break
            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1
    else:
        return None, "Upload not finished, or the input format is not supported", None, None

    all_files = sorted(os.listdir(target_dir_images))
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]

    # Update the frame_filter dropdown choices.
    frame_filter_choices = ["All"] + all_files

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    print("Running demo_fn")
    with torch.no_grad():
        predictions = demo_fn(cfg, vggt_model)

    predictions["pred_extrinsic_list"] = None

    print("Saving predictions")
    prediction_save_path = f"{target_dir}/predictions.npz"
    np.savez(prediction_save_path, **predictions)

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

    glbscene = demo_predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
    )
    glbscene.export(file_obj=glbfile)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")
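    # Return the exported GLB path for the 3D viewer, a status message, the
    # working directory (kept as Gradio state so the scene can be re-rendered
    # later), and a refreshed frame-filter dropdown.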
    log = "Success. Waiting for visualization."
    return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)


def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
    # Reload the cached predictions and re-export the GLB scene with the
    # requested confidence threshold, frame filter, and background mask.
    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

    if not os.path.exists(glbfile):
        glbscene = demo_predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization", target_dir


statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"

###########################################################################################
apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')
###########################################################################################

with gr.Blocks() as demo:
    gr.Markdown("""
# 🏛️ VGGT: Visual Geometry Grounded Transformer
Alpha version (under active development)
Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.