File size: 4,407 Bytes
fed6d1b
c3c7356
 
 
 
 
 
3a54a7a
fed6d1b
 
 
 
 
 
 
 
 
 
3a54a7a
fed6d1b
 
 
3a54a7a
fed6d1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a54a7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from multi_view import MultiViewDiffusion
from vision_llm import VisionLLM
from llama_mesh import LLaMAMesh
from mast3r import MASt3R
from utils import apply_gradient_color
from utils import create_image_grid
import os
import torch

# HTML banner shown at the top of the page (rendered by gr.Markdown in the UI).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TimeForge: Temporal Mesh Synthesis</h1>
<p> This demo showcases a fusion of state-of-the-art generative models to create 3D representations with temporal variations. </p>
</div>
'''

# Use the GPU whenever PyTorch reports one; this value is passed as `device=`
# to every model constructor.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Optional Hugging Face token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Initialize models once at import time; process_input reuses these
# module-level instances on every request. All four are built on DEVICE and
# only the vision LLM receives the Hugging Face token.
mv_diff = MultiViewDiffusion(device=DEVICE)  # prompt -> list of view images (generate_views)
vllm = VisionLLM(device=DEVICE, use_auth_token=HF_TOKEN)  # describes the generated views (describe_images)
llama_mesh = LLaMAMesh(device=DEVICE)  # text prompt -> mesh (generate_mesh)
mast3r = MASt3R(device=DEVICE)  # list of images -> point cloud (generate_point_cloud)


@torch.no_grad()
def process_input(input_prompt, num_views=4, guidance_scale=5, num_inference_steps=30, elevation=0):
    """Run the full TimeForge pipeline for a single text prompt.

    The prompt is turned into multiple views by the multi-view diffusion
    model. The first view's vision-LLM description is extended with
    "futuristic" keywords and fed to LLaMA-Mesh. MASt3R reconstructs a point
    cloud from the first generated view.

    Args:
        input_prompt: Text prompt entered by the user.
        num_views: Number of views to generate (coerced to int).
        guidance_scale: Guidance scale forwarded to the multi-view diffusion model.
        num_inference_steps: Number of diffusion steps (coerced to int).
        elevation: Elevation forwarded to the multi-view diffusion model
            (the UI slider spans -90..90).

    Returns:
        Tuple ``(multi_view_image_grid, future_mesh, past_point_cloud)`` in the
        same order as the outputs wired to the Run button.

    Raises:
        gr.Error: If the prompt is empty or whitespace only, so the UI shows a
            message instead of running every model on an empty prompt.
    """
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # gr.Slider passes its value to the callback as a float; these two are
    # counts, so hand the models proper integers.
    num_views = int(num_views)
    num_inference_steps = int(num_inference_steps)

    # MultiView Diffusion
    multi_view_images = mv_diff.generate_views(
        input_prompt, num_views, guidance_scale, num_inference_steps, elevation
    )
    multi_view_image_grid = create_image_grid(multi_view_images)

    # Vision LLM analysis: every view is described, but only the first
    # description seeds the mesh prompt below.
    descriptions = vllm.describe_images(
        multi_view_images,
        "Describe the object in the image, highlight its textures, material, and shape, and it's context, like environment and lighting:",
    )
    refined_future_prompt = descriptions[0] + " futuristic, advanced, streamlined, evolved, modern "

    # LLaMA-Mesh generation of the "future" variant.
    future_mesh = llama_mesh.generate_mesh(refined_future_prompt)

    # MASt3R point cloud from the first generated view only; returned as the
    # "past" output.
    past_point_cloud = mast3r.generate_point_cloud([multi_view_images[0]])
    return multi_view_image_grid, future_mesh, past_point_cloud

# Gradio UI: a settings column and an output column for the main pipeline,
# followed by two standalone viewers that run apply_gradient_color on
# user-pasted input (a MASt3R file path for "past", OBJ text for "future").
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=3):
            input_prompt = gr.Textbox(lines=2, placeholder="Enter prompt (e.g., 'A futuristic cyber-temple, once an ancient ruin')", label="Input Prompt")
            num_views = gr.Slider(minimum=2, maximum=8, value=4, step=1, label="Number of Views")
            guidance_scale = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps")
            elevation = gr.Slider(minimum=-90, maximum=90, value=0, step=1, label="Elevation")
            run_button = gr.Button("Run")

        with gr.Column(scale=4):
            multi_view_grid_out = gr.Image(label="Multi-view Images Output", height=300)
            with gr.Tab("Future Mesh"):
                future_mesh_output = gr.Model3D(label="Future 3D Mesh output")
            with gr.Tab("Past Point Cloud"):
                past_point_cloud_output = gr.File(label="Past 3D Point Cloud")

    # The outputs list must stay in the same order as the tuple process_input returns.
    run_button.click(
        fn=process_input,
        inputs=[input_prompt, num_views, guidance_scale, num_inference_steps, elevation],
        outputs=[multi_view_grid_out, future_mesh_output, past_point_cloud_output],
    )

    gr.Markdown("## Mesh Visualization (Past)")
    with gr.Row():
        with gr.Column():
            past_mesh_input = gr.Textbox(
                label="Past Point Cloud Input",
                placeholder="Paste your MASt3R file path here...",
                lines=2,
            )
            visualize_past_mesh_button = gr.Button("Visualize Past Mesh")
        with gr.Column():
            past_mesh_output = gr.Model3D(label="Past 3D Visualization")

    visualize_past_mesh_button.click(
        fn=apply_gradient_color,
        inputs=[past_mesh_input],
        outputs=[past_mesh_output],
    )

    gr.Markdown("## Mesh Visualization (Future)")
    with gr.Row():
        with gr.Column():
            future_mesh_input = gr.Textbox(
                label="Future Mesh Input",
                placeholder="Paste your 3D mesh in OBJ format here...",
                lines=2,
            )
            visualize_future_mesh_button = gr.Button("Visualize Future Mesh")
        with gr.Column():
            future_mesh_output_2 = gr.Model3D(label="Future 3D Visualization")

    visualize_future_mesh_button.click(
        fn=apply_gradient_color,
        inputs=[future_mesh_input],
        outputs=[future_mesh_output_2],
    )

# Building the Blocks does not serve it; start the server only when this file
# is executed as a script so importing the module has no network side effect.
if __name__ == "__main__":
    demo.launch()