Ryukijano commited on
Commit
fed6d1b
Β·
verified Β·
1 Parent(s): f7427dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # timeforge/app.py
2
+ import gradio as gr
3
+ from timeforge.multi_view import MultiViewDiffusion
4
+ from timeforge.vision_llm import VisionLLM
5
+ from timeforge.llama_mesh import LLaMAMesh
6
+ from timeforge.mast3r import MASt3R
7
+ from timeforge.utils import apply_gradient_color
8
+ from timeforge.utils import create_image_grid
9
+
10
+ import torch
11
+
12
+ DESCRIPTION = '''
13
+ <div>
14
+ <h1 style="text-align: center;">TimeForge: Temporal Mesh Synthesis</h1>
15
+ <p> This demo showcases a fusion of state-of-the-art generative models to create 3D representations with temporal variations. </p>
16
+ </div>
17
+ '''
18
+
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # Initialize models
22
+ mv_diff = MultiViewDiffusion(device=DEVICE)
23
+ vllm = VisionLLM(device=DEVICE)
24
+ llama_mesh = LLaMAMesh(device=DEVICE)
25
+ mast3r = MASt3R(device=DEVICE)
26
+
27
+
28
+ @torch.no_grad()
29
+ def process_input(input_prompt, num_views=4, guidance_scale=5, num_inference_steps=30, elevation=0):
30
+ # MultiView Diffusion
31
+ multi_view_images = mv_diff.generate_views(input_prompt, num_views, guidance_scale, num_inference_steps, elevation)
32
+ multi_view_image_grid = create_image_grid(multi_view_images)
33
+ # Vision LLM Analysis
34
+ descriptions = vllm.describe_images(multi_view_images, f"Describe the object in the image, highlight its textures, material, and shape, and it's context, like environment and lighting:")
35
+ refined_past_prompt = descriptions[0] + " ancient, weathered, eroded, original "
36
+ refined_future_prompt = descriptions[0] + " futuristic, advanced, streamlined, evolved, modern "
37
+ # LLaMA-Mesh Generation
38
+ future_mesh = llama_mesh.generate_mesh(refined_future_prompt)
39
+ # MASt3R Point Cloud Generation
40
+ past_point_cloud = mast3r.generate_point_cloud([multi_view_images[0]])
41
+ return multi_view_image_grid, future_mesh, past_point_cloud
42
+
43
+ with gr.Blocks() as demo:
44
+ gr.Markdown(DESCRIPTION)
45
+ with gr.Row():
46
+ with gr.Column(scale=3):
47
+ input_prompt = gr.Textbox(lines=2, placeholder="Enter prompt (e.g., 'A futuristic cyber-temple, once an ancient ruin')", label="Input Prompt")
48
+ num_views = gr.Slider(minimum=2, maximum=8, value=4, step=1, label="Number of Views")
49
+ guidance_scale = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="Guidance Scale")
50
+ num_inference_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps")
51
+ elevation = gr.Slider(minimum=-90, maximum=90, value=0, step=1, label="Elevation")
52
+ run_button = gr.Button("Run")
53
+
54
+ with gr.Column(scale=4):
55
+ multi_view_grid_out = gr.Image(label = "Multi-view Images Output", height=300)
56
+ with gr.Tab("Future Mesh"):
57
+ future_mesh_output = gr.Model3D(label = "Future 3D Mesh output")
58
+ with gr.Tab("Past Point Cloud"):
59
+ past_point_cloud_output = gr.File(label = "Past 3D Point Cloud")
60
+
61
+ run_button.click(
62
+ fn=process_input,
63
+ inputs=[input_prompt, num_views, guidance_scale, num_inference_steps, elevation],
64
+ outputs=[multi_view_grid_out, future_mesh_output, past_point_cloud_output],
65
+ )
66
+ gr.Markdown("## Mesh Visualization (Past)")
67
+ with gr.Row():
68
+ with gr.Column():
69
+ past_mesh_input = gr.Textbox(
70
+ label="Past Point Cloud Input",
71
+ placeholder="Paste your MASt3R file path here...",
72
+ lines=2,
73
+ )
74
+ visualize_past_mesh_button = gr.Button("Visualize Past Mesh")
75
+ with gr.Column():
76
+ past_mesh_output = gr.Model3D(label = "Past 3D Visualization")
77
+
78
+
79
+ visualize_past_mesh_button.click(
80
+ fn=apply_gradient_color,
81
+ inputs=[past_mesh_input],
82
+ outputs=[past_mesh_output]
83
+ )
84
+ gr.Markdown("## Mesh Visualization (Future)")
85
+ with gr.Row():
86
+ with gr.Column():
87
+ future_mesh_input = gr.Textbox(
88
+ label="Future Mesh Input",
89
+ placeholder="Paste your 3D mesh in OBJ format here...",
90
+ lines=2,
91
+ )
92
+ visualize_future_mesh_button = gr.Button("Visualize Future Mesh")
93
+ with gr.Column():
94
+ future_mesh_output_2 = gr.Model3D(label = "Future 3D Visualization")
95
+ visualize_future_mesh_button.click(
96
+ fn=apply_gradient_color,
97
+ inputs=[future_mesh_input],
98
+ outputs=[future_mesh_output_2]
99
+ )
100
+
101
+ demo.launch()