supersolar commited on
Commit
42c15b4
·
verified ·
1 Parent(s): 2c27418

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +11 -8
infer.py CHANGED
@@ -53,13 +53,14 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
53
  ).images[0]
54
 
55
  # Post-process the prediction
 
56
  if task_name == 'depth':
57
  output_npy = pred.mean(axis=-1)
58
- output_color = colorize_depth_map(output_npy)
 
59
  else:
60
  output_npy = pred
61
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
62
-
63
  return output_color
64
 
65
  def lotus_video(input_video, task_name, seed, device):
@@ -121,14 +122,15 @@ def lotus_video(input_video, task_name, seed, device):
121
  task_emb=task_emb,
122
  ).images[0]
123
  # Post-process the prediction
 
124
  if task_name == 'depth':
125
  output_npy_g = pred_g.mean(axis=-1)
126
- output_color_g = colorize_depth_map(output_npy_g)
 
127
  else:
128
  output_npy_g = pred_g
129
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
130
-
131
- output_g.append(output_color_g)
132
 
133
 
134
  return output_g
@@ -305,13 +307,14 @@ def main():
305
 
306
  # Post-process the prediction
307
  save_file_name = os.path.basename(test_images[i])[:-4]
308
- if args.task_name == 'depth':
 
309
  output_npy = pred.mean(axis=-1)
310
- output_color = colorize_depth_map(output_npy)
 
311
  else:
312
  output_npy = pred
313
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
314
-
315
  output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
316
  np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
317
 
 
53
  ).images[0]
54
 
55
  # Post-process the prediction
56
+ # 在 infer_pipe 函数中
57
  if task_name == 'depth':
58
  output_npy = pred.mean(axis=-1)
59
+ # 修改为输出灰度图
60
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
61
  else:
62
  output_npy = pred
63
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
64
  return output_color
65
 
66
  def lotus_video(input_video, task_name, seed, device):
 
122
  task_emb=task_emb,
123
  ).images[0]
124
  # Post-process the prediction
125
+ # 在 lotus_video 函数中
126
  if task_name == 'depth':
127
  output_npy_g = pred_g.mean(axis=-1)
128
+ # 修改为输出灰度图
129
+ output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8), mode='L')
130
  else:
131
  output_npy_g = pred_g
132
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
133
+ output_g.append(output_color_g)
 
134
 
135
 
136
  return output_g
 
307
 
308
  # Post-process the prediction
309
  save_file_name = os.path.basename(test_images[i])[:-4]
310
+ # infer_pipe 函数中
311
+ if task_name == 'depth':
312
  output_npy = pred.mean(axis=-1)
313
+ # 修改为输出灰度图
314
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8), mode='L')
315
  else:
316
  output_npy = pred
317
  output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
 
318
  output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
319
  np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
320