# vggt/clean_app.py
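"""Gradio demo for VGGT (Visual Geometry Grounded Transformer).

Takes an uploaded video or a set of images, runs the VGGT reconstruction
pipeline, and serves the resulting point cloud and camera poses through a
per-run viser server that the UI embeds as an iframe.
"""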

import gc
import glob
import os
import shutil
import socket
import sys
import time
from datetime import datetime

import cv2
import gradio as gr
import torch
from omegaconf import OmegaConf

# Local modules from the VGGT repo.
sys.path.append("vggt/")
from demo_hf import demo_fn
from viser_fn import viser_wrapper


def get_free_port():
    """Get a free port by binding to port 0 and reading the assigned port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        port = s.getsockname()[1]
        return port
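

# Each reconstruction launches its own viser server on a freshly allocated
# port; the Gradio UI embeds that server as an iframe (see process_example).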
def vggt_demo(input_video, input_image):
    """Run VGGT on an uploaded video or image set and return (None, viser_port)."""
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Stage the inputs in a timestamped directory so runs don't collide.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images)
    if input_video is not None:
        if not isinstance(input_video, str):
            # Gradio may deliver the video as a dict of file metadata.
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        # Sort for a deterministic frame order, then stage the files.
        input_image = sorted(input_image)
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        # Sample roughly one frame per second from the uploaded video.
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)
        frame_rate = 1  # seconds between sampled frames
        frame_interval = max(int(fps * frame_rate), 1)  # guard against fps == 0
        video_frame_num = 0
        count = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                cv2.imwrite(os.path.join(target_dir_images, f"{video_frame_num:06}.png"), frame)
                video_frame_num += 1
        vs.release()
    else:
        return None, "Uploading not finished or incorrect input format"
print(f"Files have been copied to {target_dir_images}")
cfg.SCENE_DIR = target_dir
predictions = demo_fn(cfg)
# Get a free port for viser
viser_port = get_free_port()
# Start viser visualization in a separate thread/process
viser_wrapper(predictions, port=viser_port)
del predictions
gc.collect()
torch.cuda.empty_cache()
print(input_image)
print(input_video)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time: {execution_time} seconds")
return None, viser_port
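

# Example assets bundled with the demo; each video below is paired with a
# matching image set in the Examples gallery.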
statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"
###########################################################################################
apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')
###########################################################################################
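# UI layout: uploaders in a narrow left column; the viser iframe and a log box
# in a wider right column.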
with gr.Blocks() as demo:
    gr.Markdown("""
    # πŸ›οΈ VGGT: Visual Geometry Grounded Transformer
    <div style="font-size: 16px; line-height: 1.2;">
    Alpha version (testing).
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

        with gr.Column(scale=3):
            viser_output = gr.HTML(
                label="Viser Visualization",
                value='''<div style="height: 520px; border: 1px solid #e0e0e0;
                             border-radius: 4px; padding: 16px;
                             display: flex; align-items: center;
                             justify-content: center">
                             3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
                         </div>''',
            )
            log_output = gr.Textbox(label="Log")

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1)

    examples = [
        [flower_video, flower_images],
        [kitchen_video, kitchen_images],
        [counter_video, counter_images],
        [fern_video, fern_images],
        [horns_video, horns_images],
    ]

    def process_example(video, images):
        """Run vggt_demo and wrap the returned viser port in an iframe."""
        _, viser_port = vggt_demo(video, images)
        viser_url = f"http://localhost:{viser_port}"
        print(f"Viser URL: {viser_url}")

        # Embed the viser server; the height matches the placeholder div above.
        iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'
        return iframe_code, f"Visualization running at {viser_url}"
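
    # NOTE: the iframe points at localhost, so the visualization renders only
    # when the browser can reach the viser port on this machine (same host, or
    # with the port forwarded/proxied); this is a deployment assumption, not
    # something the code checks.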

    gr.Examples(
        examples=examples,
        inputs=[input_video, input_images],
        outputs=[viser_output, log_output],
        fn=process_example,
        cache_examples=False,
        examples_per_page=50,
    )

    submit_btn.click(
        process_example,
        [input_video, input_images],
        [viser_output, log_output],
        concurrency_limit=1,
    )
demo.queue(max_size=20).launch(show_error=True, share=True, server_port=7888, server_name="0.0.0.0")
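
# share=True also requests a temporary public gradio.live link in addition to
# serving locally on 0.0.0.0:7888.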