JeffreyXiang commited on
Commit
db894f7
·
1 Parent(s): 690b53e

fix gradio image compression

Browse files
Files changed (2) hide show
  1. app.py +15 -12
  2. trellis/pipelines/trellis_image_to_3d.py +3 -1
app.py CHANGED
@@ -19,7 +19,7 @@ from trellis.utils import render_utils, postprocessing_utils
19
  MAX_SEED = np.iinfo(np.int32).max
20
 
21
 
22
- def preprocess_image(image: Image.Image) -> Image.Image:
23
  """
24
  Preprocess the input image.
25
 
@@ -27,9 +27,11 @@ def preprocess_image(image: Image.Image) -> Image.Image:
27
  image (Image.Image): The input image.
28
 
29
  Returns:
 
30
  Image.Image: The preprocessed image.
31
  """
32
- return pipeline.preprocess_image(image)
 
33
 
34
 
35
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, model_id: str) -> dict:
@@ -74,12 +76,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
74
 
75
 
76
  @spaces.GPU
77
- def image_to_3d(image: Image.Image, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
78
  """
79
  Convert an image to a 3D model.
80
 
81
  Args:
82
- image (Image.Image): The input image.
83
  seed (int): The random seed.
84
  randomize_seed (bool): Whether to randomize the seed.
85
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
@@ -93,9 +95,9 @@ def image_to_3d(image: Image.Image, seed: int, randomize_seed: bool, ss_guidance
93
  """
94
  if randomize_seed:
95
  seed = np.random.randint(0, MAX_SEED)
96
- torch.manual_seed(seed)
97
- outputs = pipeline(
98
- image,
99
  formats=["gaussian", "mesh"],
100
  preprocess_image=False,
101
  sparse_structure_sampler_params={
@@ -181,6 +183,9 @@ with gr.Blocks() as demo:
181
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
182
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
183
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
 
 
 
184
 
185
  # Example images at the bottom of the page
186
  with gr.Row():
@@ -191,23 +196,21 @@ with gr.Blocks() as demo:
191
  ],
192
  inputs=[image_prompt],
193
  fn=lambda image: preprocess_image(image),
194
- outputs=[image_prompt],
195
  run_on_click=True,
196
  examples_per_page=64,
197
  )
198
 
199
- model = gr.State()
200
-
201
  # Handlers
202
  image_prompt.upload(
203
  preprocess_image,
204
  inputs=[image_prompt],
205
- outputs=[image_prompt],
206
  )
207
 
208
  generate_btn.click(
209
  image_to_3d,
210
- inputs=[image_prompt, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
211
  outputs=[model, video_output],
212
  ).then(
213
  activate_button,
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
 
21
 
22
+ def preprocess_image(image: Image.Image) -> Tuple[np.array, Image.Image]:
23
  """
24
  Preprocess the input image.
25
 
 
27
  image (Image.Image): The input image.
28
 
29
  Returns:
30
+ np.array: The preprocessed image.
31
  Image.Image: The preprocessed image.
32
  """
33
+ processed_image = pipeline.preprocess_image(image)
34
+ return np.array(processed_image), processed_image
35
 
36
 
37
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, model_id: str) -> dict:
 
76
 
77
 
78
  @spaces.GPU
79
+ def image_to_3d(image: np.array, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
80
  """
81
  Convert an image to a 3D model.
82
 
83
  Args:
84
+ image (np.array): The input image.
85
  seed (int): The random seed.
86
  randomize_seed (bool): Whether to randomize the seed.
87
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
 
95
  """
96
  if randomize_seed:
97
  seed = np.random.randint(0, MAX_SEED)
98
+ outputs = pipeline.run(
99
+ Image.fromarray(image),
100
+ seed=seed,
101
  formats=["gaussian", "mesh"],
102
  preprocess_image=False,
103
  sparse_structure_sampler_params={
 
183
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
184
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
185
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
186
+
187
+ image = gr.State()
188
+ model = gr.State()
189
 
190
  # Example images at the bottom of the page
191
  with gr.Row():
 
196
  ],
197
  inputs=[image_prompt],
198
  fn=lambda image: preprocess_image(image),
199
+ outputs=[image, image_prompt],
200
  run_on_click=True,
201
  examples_per_page=64,
202
  )
203
 
 
 
204
  # Handlers
205
  image_prompt.upload(
206
  preprocess_image,
207
  inputs=[image_prompt],
208
+ outputs=[image, image_prompt],
209
  )
210
 
211
  generate_btn.click(
212
  image_to_3d,
213
+ inputs=[image, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
214
  outputs=[model, video_output],
215
  ).then(
216
  activate_button,
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -254,10 +254,11 @@ class TrellisImageTo3DPipeline(Pipeline):
254
  return slat
255
 
256
  @torch.no_grad()
257
- def __call__(
258
  self,
259
  image: Image.Image,
260
  num_samples: int = 1,
 
261
  sparse_structure_sampler_params: dict = {},
262
  slat_sampler_params: dict = {},
263
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
@@ -276,6 +277,7 @@ class TrellisImageTo3DPipeline(Pipeline):
276
  if preprocess_image:
277
  image = self.preprocess_image(image)
278
  cond = self.get_cond([image])
 
279
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
280
  slat = self.sample_slat(cond, coords, slat_sampler_params)
281
  return self.decode_slat(slat, formats)
 
254
  return slat
255
 
256
  @torch.no_grad()
257
+ def run(
258
  self,
259
  image: Image.Image,
260
  num_samples: int = 1,
261
+ seed: int = 42,
262
  sparse_structure_sampler_params: dict = {},
263
  slat_sampler_params: dict = {},
264
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
 
277
  if preprocess_image:
278
  image = self.preprocess_image(image)
279
  cond = self.get_cond([image])
280
+ torch.manual_seed(seed)
281
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
282
  slat = self.sample_slat(cond, coords, slat_sampler_params)
283
  return self.decode_slat(slat, formats)