salomonsky commited on
Commit
a024162
·
verified ·
1 Parent(s): b9f491e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -11,14 +11,15 @@ import uuid
11
  import random
12
  from huggingface_hub import hf_hub_download
13
  import spaces
 
 
 
14
 
15
  pipe = StableVideoDiffusionPipeline.from_pretrained(
16
  "vdo/stable-video-diffusion-img2vid-xt-1-1", torch_dtype=torch.float16, variant="fp16"
17
  )
18
  pipe.to("cpu")
19
 
20
- max_64_bit_int = 2**63 - 1
21
-
22
  @spaces.GPU(duration=120)
23
  def sample(
24
  image: Image,
@@ -31,6 +32,7 @@ def sample(
31
  decoding_t: int = 3,
32
  device: str = "cuda",
33
  output_folder: str = "outputs",
 
34
  ):
35
  if image.mode == "RGBA":
36
  image = image.convert("RGB")
@@ -43,12 +45,18 @@ def sample(
43
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
44
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
45
 
46
- frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1, num_frames=25).frames[0]
 
 
 
 
 
47
  export_to_video(frames, video_path, fps=fps_id)
48
  torch.manual_seed(seed)
49
 
50
  return video_path, frames, seed
51
 
 
52
  def resize_image(image, output_size=(1024, 576)):
53
  target_aspect = output_size[0] / output_size[1]
54
  image_aspect = image.width / image.height
@@ -73,6 +81,7 @@ def resize_image(image, output_size=(1024, 576)):
73
  cropped_image = resized_image.crop((left, top, right, bottom))
74
  return cropped_image
75
 
 
76
  with gr.Blocks() as demo:
77
  with gr.Row():
78
  with gr.Column():
@@ -86,9 +95,11 @@ with gr.Blocks() as demo:
86
  with gr.Column():
87
  video = gr.Video(label="Generated video")
88
  gallery = gr.Gallery(label="Generated frames")
 
89
 
90
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
91
- generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, gallery, seed], api_name="video")
 
92
 
93
  if __name__ == "__main__":
94
  demo.launch(share=True, show_api=False)
 
11
  import random
12
  from huggingface_hub import hf_hub_download
13
  import spaces
14
+ from tqdm import tqdm
15
+
16
+ max_64_bit_int = 2**63 - 1
17
 
18
  pipe = StableVideoDiffusionPipeline.from_pretrained(
19
  "vdo/stable-video-diffusion-img2vid-xt-1-1", torch_dtype=torch.float16, variant="fp16"
20
  )
21
  pipe.to("cpu")
22
 
 
 
23
  @spaces.GPU(duration=120)
24
  def sample(
25
  image: Image,
 
32
  decoding_t: int = 3,
33
  device: str = "cuda",
34
  output_folder: str = "outputs",
35
+ progress: gr.Progress,
36
  ):
37
  if image.mode == "RGBA":
38
  image = image.convert("RGB")
 
45
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
46
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
47
 
48
+ frames = []
49
+ for i in tqdm(range(25), desc="Generando frames"):
50
+ frame = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1, num_frames=1).frames[0]
51
+ frames.extend(frame)
52
+ progress.update(i/25)
53
+
54
  export_to_video(frames, video_path, fps=fps_id)
55
  torch.manual_seed(seed)
56
 
57
  return video_path, frames, seed
58
 
59
+
60
  def resize_image(image, output_size=(1024, 576)):
61
  target_aspect = output_size[0] / output_size[1]
62
  image_aspect = image.width / image.height
 
81
  cropped_image = resized_image.crop((left, top, right, bottom))
82
  return cropped_image
83
 
84
+
85
  with gr.Blocks() as demo:
86
  with gr.Row():
87
  with gr.Column():
 
95
  with gr.Column():
96
  video = gr.Video(label="Generated video")
97
  gallery = gr.Gallery(label="Generated frames")
98
+ progress = gr.Progress(label="Progress")
99
 
100
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
101
+ generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, "svd_xt", 0.02, 3, "cuda", "outputs", progress], outputs=[video, gallery, seed, progress], api_name="video")
102
+
103
 
104
  if __name__ == "__main__":
105
  demo.launch(share=True, show_api=False)