supersolar committed on
Commit
2c27418
·
verified ·
1 Parent(s): f8a8adc

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +10 -39
infer.py CHANGED
@@ -65,25 +65,17 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
65
  def lotus_video(input_video, task_name, seed, device):
66
  if task_name == 'depth':
67
  model_g = 'jingheya/lotus-depth-g-v1-0'
68
- model_d = 'jingheya/lotus-depth-d-v1-0'
69
  else:
70
  model_g = 'jingheya/lotus-normal-g-v1-0'
71
- model_d = 'jingheya/lotus-normal-d-v1-0'
72
 
73
  dtype = torch.float16
74
  pipe_g = LotusGPipeline.from_pretrained(
75
  model_g,
76
  torch_dtype=dtype,
77
  )
78
- pipe_d = LotusDPipeline.from_pretrained(
79
- model_d,
80
- torch_dtype=dtype,
81
- )
82
  pipe_g.to(device)
83
- pipe_d.to(device)
84
  pipe_g.set_progress_bar_config(disable=True)
85
- pipe_d.set_progress_bar_config(disable=True)
86
- logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
87
 
88
  # load the video and split it into frames
89
  cap = cv2.VideoCapture(input_video)
@@ -105,7 +97,6 @@ def lotus_video(input_video, task_name, seed, device):
105
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
106
 
107
  output_g = []
108
- output_d = []
109
  for frame in frames:
110
  if torch.backends.mps.is_available():
111
  autocast_ctx = nullcontext()
@@ -129,59 +120,39 @@ def lotus_video(input_video, task_name, seed, device):
129
  timesteps=[999],
130
  task_emb=task_emb,
131
  ).images[0]
132
- pred_d = pipe_d(
133
- rgb_in=test_image,
134
- prompt='',
135
- num_inference_steps=1,
136
- generator=generator,
137
- # guidance_scale=0,
138
- output_type='np',
139
- timesteps=[999],
140
- task_emb=task_emb,
141
- ).images[0]
142
-
143
  # Post-process the prediction
144
  if task_name == 'depth':
145
  output_npy_g = pred_g.mean(axis=-1)
146
  output_color_g = colorize_depth_map(output_npy_g)
147
- output_npy_d = pred_d.mean(axis=-1)
148
- output_color_d = colorize_depth_map(output_npy_d)
149
  else:
150
  output_npy_g = pred_g
151
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
152
- output_npy_d = pred_d
153
- output_color_d = Image.fromarray((output_npy_d * 255).astype(np.uint8))
154
-
155
  output_g.append(output_color_g)
156
- output_d.append(output_color_d)
157
 
158
- return output_g, output_d
159
 
160
  def lotus(image_input, task_name, seed, device):
161
  if task_name == 'depth':
162
  model_g = 'jingheya/lotus-depth-g-v1-0'
163
- model_d = 'jingheya/lotus-depth-d-v1-1'
164
  else:
165
  model_g = 'jingheya/lotus-normal-g-v1-0'
166
- model_d = 'jingheya/lotus-normal-d-v1-0'
167
 
168
  dtype = torch.float16
169
  pipe_g = LotusGPipeline.from_pretrained(
170
  model_g,
171
  torch_dtype=dtype,
172
  )
173
- pipe_d = LotusDPipeline.from_pretrained(
174
- model_d,
175
- torch_dtype=dtype,
176
- )
177
  pipe_g.to(device)
178
- pipe_d.to(device)
179
  pipe_g.set_progress_bar_config(disable=True)
180
- pipe_d.set_progress_bar_config(disable=True)
181
- logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
182
  output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
183
- output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
184
- return output_g, output_d
185
 
186
  def parse_args():
187
  '''Set the Args'''
 
65
  def lotus_video(input_video, task_name, seed, device):
66
  if task_name == 'depth':
67
  model_g = 'jingheya/lotus-depth-g-v1-0'
 
68
  else:
69
  model_g = 'jingheya/lotus-normal-g-v1-0'
 
70
 
71
  dtype = torch.float16
72
  pipe_g = LotusGPipeline.from_pretrained(
73
  model_g,
74
  torch_dtype=dtype,
75
  )
 
 
 
 
76
  pipe_g.to(device)
 
77
  pipe_g.set_progress_bar_config(disable=True)
78
+ logging.info(f"Successfully loading pipeline from {model_g}.")
 
79
 
80
  # load the video and split it into frames
81
  cap = cv2.VideoCapture(input_video)
 
97
  task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
98
 
99
  output_g = []
 
100
  for frame in frames:
101
  if torch.backends.mps.is_available():
102
  autocast_ctx = nullcontext()
 
120
  timesteps=[999],
121
  task_emb=task_emb,
122
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
123
  # Post-process the prediction
124
  if task_name == 'depth':
125
  output_npy_g = pred_g.mean(axis=-1)
126
  output_color_g = colorize_depth_map(output_npy_g)
 
 
127
  else:
128
  output_npy_g = pred_g
129
  output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
130
+
 
 
131
  output_g.append(output_color_g)
132
+
133
 
134
+ return output_g
135
 
136
  def lotus(image_input, task_name, seed, device):
137
  if task_name == 'depth':
138
  model_g = 'jingheya/lotus-depth-g-v1-0'
 
139
  else:
140
  model_g = 'jingheya/lotus-normal-g-v1-0'
 
141
 
142
  dtype = torch.float16
143
  pipe_g = LotusGPipeline.from_pretrained(
144
  model_g,
145
  torch_dtype=dtype,
146
  )
147
+
 
 
 
148
  pipe_g.to(device)
149
+
150
  pipe_g.set_progress_bar_config(disable=True)
151
+
152
+ logging.info(f"Successfully loading pipeline from {model_g}.")
153
  output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
154
+
155
+ return output_g
156
 
157
  def parse_args():
158
  '''Set the Args'''