xi0v Fabrice-TIERCELIN committed on
Commit
1759c53
·
verified ·
1 Parent(s): 8a3ce8b

Number inference steps (#45)

Browse files

- Number inference steps (90c77ceadccbffbb521dad66b02d8801f21d5a68)


Co-authored-by: Fabrice TIERCELIN <[email protected]>

Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -43,7 +43,8 @@ def animate(
43
  version: str = "auto",
44
  width: int = 1024,
45
  height: int = 576,
46
- motion_control: bool = False
 
47
  ):
48
  start = time.time()
49
 
@@ -56,7 +57,7 @@ def animate(
56
  image_data = image_data.convert("RGB")
57
 
58
  if motion_control:
59
- image_data = [image_data] * 25
60
 
61
  if randomize_seed:
62
  seed = random.randint(0, max_64_bit_int)
@@ -76,7 +77,8 @@ def animate(
76
  decoding_t,
77
  version,
78
  width,
79
- height
 
80
  )
81
 
82
  os.makedirs(output_folder, exist_ok=True)
@@ -133,16 +135,17 @@ def animate_on_gpu(
133
  decoding_t: int = 3,
134
  version: str = "svdxt",
135
  width: int = 1024,
136
- height: int = 576
 
137
  ):
138
  generator = torch.manual_seed(seed)
139
 
140
  if version == "dragnuwa":
141
- return dragnuwaPipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
142
  elif version == "svdxt":
143
- return fps25Pipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
144
  else:
145
- return fps14Pipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
146
 
147
 
148
  def resize_image(image, output_size=(1024, 576)):
@@ -193,7 +196,8 @@ def reset():
193
  "auto",
194
  1024,
195
  576,
196
- False
 
197
  ]
198
 
199
  with gr.Blocks() as demo:
@@ -215,12 +219,13 @@ with gr.Blocks() as demo:
215
  with gr.Accordion("Advanced options", open=False):
216
  width = gr.Slider(label="Width", info="Width of the video", value=1024, minimum=256, maximum=1024, step=8)
217
  height = gr.Slider(label="Height", info="Height of the video", value=576, minimum=256, maximum=576, step=8)
218
- motion_control = gr.Checkbox(label="Motion control (fixed camera)", info="Fix the camera", value=False)
219
  video_format = gr.Radio([["*.mp4", "mp4"], ["*.avi", "avi"], ["*.wmv", "wmv"], ["*.mkv", "mkv"], ["*.mov", "mov"], ["*.gif", "gif"]], label="Video format for result", info="File extention", value="mp4", interactive=True)
220
  frame_format = gr.Radio([["*.webp", "webp"], ["*.png", "png"], ["*.jpeg", "jpeg"], ["*.gif (unanimated)", "gif"], ["*.bmp", "bmp"]], label="Image format for frames", info="File extention", value="webp", interactive=True)
221
  fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=25, minimum=5, maximum=30)
222
  motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
223
  noise_aug_strength = gr.Slider(label="Noise strength", info="The noise to add", value=0.1, minimum=0, maximum=1, step=0.1)
 
224
  decoding_t = gr.Slider(label="Decoding", info="Number of frames decoded at a time; this eats more VRAM; reduce if necessary", value=3, minimum=1, maximum=5, step=1)
225
  version = gr.Radio([["Auto", "auto"], ["πŸƒπŸ»β€β™€οΈ SVD (trained on 14 f/s)", "svd"], ["πŸƒπŸ»β€β™€οΈπŸ’¨ SVD-XT (trained on 25 f/s)", "svdxt"], ["DragNUWA (unstable)", "dragnuwa"]], label="Model", info="Trained model", value="auto", interactive=True)
226
  seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
@@ -249,7 +254,8 @@ with gr.Blocks() as demo:
249
  version,
250
  width,
251
  height,
252
- motion_control
 
253
  ], outputs=[
254
  video_output,
255
  gif_output,
@@ -273,16 +279,17 @@ with gr.Blocks() as demo:
273
  version,
274
  width,
275
  height,
276
- motion_control
 
277
  ], queue = False, show_progress = False)
278
 
279
  gr.Examples(
280
  examples=[
281
- ["Examples/Fire.webp", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False],
282
- ["Examples/Water.png", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False],
283
- ["Examples/Town.jpeg", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False]
284
  ],
285
- inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, video_format, frame_format, version, width, height, motion_control],
286
  outputs=[video_output, gif_output, download_button, gallery, seed, information_msg, reset_btn],
287
  fn=animate,
288
  run_on_click=True,
 
43
  version: str = "auto",
44
  width: int = 1024,
45
  height: int = 576,
46
+ motion_control: bool = False,
47
+ num_inference_steps: int = 25
48
  ):
49
  start = time.time()
50
 
 
57
  image_data = image_data.convert("RGB")
58
 
59
  if motion_control:
60
+ image_data = [image_data] * 2
61
 
62
  if randomize_seed:
63
  seed = random.randint(0, max_64_bit_int)
 
77
  decoding_t,
78
  version,
79
  width,
80
+ height,
81
+ num_inference_steps
82
  )
83
 
84
  os.makedirs(output_folder, exist_ok=True)
 
135
  decoding_t: int = 3,
136
  version: str = "svdxt",
137
  width: int = 1024,
138
+ height: int = 576,
139
+ num_inference_steps: int = 25
140
  ):
141
  generator = torch.manual_seed(seed)
142
 
143
  if version == "dragnuwa":
144
+ return dragnuwaPipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25, num_inference_steps=num_inference_steps).frames[0]
145
  elif version == "svdxt":
146
+ return fps25Pipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25, num_inference_steps=num_inference_steps).frames[0]
147
  else:
148
+ return fps14Pipe(image_data, width=width, height=height, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25, num_inference_steps=num_inference_steps).frames[0]
149
 
150
 
151
  def resize_image(image, output_size=(1024, 576)):
 
196
  "auto",
197
  1024,
198
  576,
199
+ False,
200
+ 25
201
  ]
202
 
203
  with gr.Blocks() as demo:
 
219
  with gr.Accordion("Advanced options", open=False):
220
  width = gr.Slider(label="Width", info="Width of the video", value=1024, minimum=256, maximum=1024, step=8)
221
  height = gr.Slider(label="Height", info="Height of the video", value=576, minimum=256, maximum=576, step=8)
222
+ motion_control = gr.Checkbox(label="Motion control (experimental)", info="Fix the camera", value=False)
223
  video_format = gr.Radio([["*.mp4", "mp4"], ["*.avi", "avi"], ["*.wmv", "wmv"], ["*.mkv", "mkv"], ["*.mov", "mov"], ["*.gif", "gif"]], label="Video format for result", info="File extention", value="mp4", interactive=True)
224
  frame_format = gr.Radio([["*.webp", "webp"], ["*.png", "png"], ["*.jpeg", "jpeg"], ["*.gif (unanimated)", "gif"], ["*.bmp", "bmp"]], label="Image format for frames", info="File extention", value="webp", interactive=True)
225
  fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=25, minimum=5, maximum=30)
226
  motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
227
  noise_aug_strength = gr.Slider(label="Noise strength", info="The noise to add", value=0.1, minimum=0, maximum=1, step=0.1)
228
+ num_inference_steps = gr.Slider(label="Number inference steps", info="More denoising steps usually lead to a higher quality video at the expense of slower inference", value=25, minimum=1, maximum=100, step=1)
229
  decoding_t = gr.Slider(label="Decoding", info="Number of frames decoded at a time; this eats more VRAM; reduce if necessary", value=3, minimum=1, maximum=5, step=1)
230
  version = gr.Radio([["Auto", "auto"], ["πŸƒπŸ»β€β™€οΈ SVD (trained on 14 f/s)", "svd"], ["πŸƒπŸ»β€β™€οΈπŸ’¨ SVD-XT (trained on 25 f/s)", "svdxt"], ["DragNUWA (unstable)", "dragnuwa"]], label="Model", info="Trained model", value="auto", interactive=True)
231
  seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
 
254
  version,
255
  width,
256
  height,
257
+ motion_control,
258
+ num_inference_steps
259
  ], outputs=[
260
  video_output,
261
  gif_output,
 
279
  version,
280
  width,
281
  height,
282
+ motion_control,
283
+ num_inference_steps
284
  ], queue = False, show_progress = False)
285
 
286
  gr.Examples(
287
  examples=[
288
+ ["Examples/Fire.webp", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False, 25],
289
+ ["Examples/Water.png", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False, 25],
290
+ ["Examples/Town.jpeg", 42, True, 127, 25, 0.1, 3, "mp4", "png", "auto", 1024, 576, False, 25]
291
  ],
292
+ inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, video_format, frame_format, version, width, height, motion_control, num_inference_steps],
293
  outputs=[video_output, gif_output, download_button, gallery, seed, information_msg, reset_btn],
294
  fn=animate,
295
  run_on_click=True,