jadechoghari committed
Commit 9a68e0a · verified · 1 Parent(s): 372ade2

Update app.py

Files changed (1)
  1. app.py +24 -93
app.py CHANGED
@@ -1,79 +1,28 @@
from typing import Tuple, Union
import gradio as gr
- import numpy as np
- import see2sound
- import spaces
- import torch
- import yaml
import os
- from huggingface_hub import snapshot_download
from PIL import Image

- model_id = "rishitdagli/see-2-sound"
- base_path = snapshot_download(repo_id=model_id)
-
- # load and update the configuration
- with open("config.yaml", "r") as file:
-     data = yaml.safe_load(file)
- data_str = yaml.dump(data)
- updated_data_str = data_str.replace("checkpoints", base_path)
- updated_data = yaml.safe_load(updated_data_str)
- with open("config.yaml", "w") as file:
-     yaml.safe_dump(updated_data, file)
-
- model = see2sound.See2Sound(config_path="config.yaml")
- model.setup()
-
CACHE_DIR = "gradio_cached_examples"

- # function to create cached output directory
- def create_cache_dir(image_path):
-     image_name = os.path.basename(image_path).split('.')[0]
-     cached_dir = os.path.join(CACHE_DIR, image_name)
-     os.makedirs(cached_dir, exist_ok=True)
-     return cached_dir
-
- # fn to process image and cache outputs
- @spaces.GPU(duration=280)
- @torch.no_grad()
- def process_image(
-     image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
- ) -> Tuple[str, str]:
-     cached_dir = create_cache_dir(image)
+
+ def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
+     cached_dir = os.path.join(CACHE_DIR, str(example_index))  # Use the example index to find the directory
    cached_image_path = os.path.join(cached_dir, "processed_image.png")
    cached_audio_path = os.path.join(cached_dir, "audio.wav")

-     # check if cached outputs exist, if yes, return them
    if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
        return cached_image_path, cached_audio_path
-
-     # run the model if outputs are not cached
-     model.run(
-         path=image,
-         output_path=cached_audio_path,  # Save audio in cache directory
-         num_audios=num_audios,
-         prompt=prompt,
-         steps=steps,
-     )
-
-     # save the processed image to the cache directory (use original image or any transformations)
-     processed_image = Image.open(image)  # Assuming image is a file path
-     processed_image.save(cached_image_path)
-
-     return cached_image_path, cached_audio_path
-
+     else:
+         raise FileNotFoundError(f"Cached outputs not found for example {example_index}")

description_text = """# SEE-2-SOUND 🔊 Demo

Official demo for *SEE-2-SOUND 🔊: Zero-Shot Spatial Environment-to-Spatial Sound*.
- Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [github](https://github.com/see2sound/see2sound) for more details.
- > Note: You should make sure that your hardware supports spatial audio.
"""

css = """
- h1 {
-     text-align: center;
- }
+ h1 { text-align: center; }
"""

with gr.Blocks(css=css) as demo:
@@ -81,56 +30,38 @@ with gr.Blocks(css=css) as demo:

    with gr.Row():
        with gr.Column():
-             image = gr.Image(
-                 label="Select an image", sources=["upload", "webcam"], type="filepath"
-             )
+             image = gr.Image(label="Select an image", sources=["upload", "webcam"], type="filepath")

            with gr.Accordion("Advanced Settings", open=False):
-                 steps = gr.Slider(
-                     label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500
-                 )
-                 prompt = gr.Text(
-                     label="Prompt",
-                     show_label=True,
-                     max_lines=1,
-                     placeholder="Enter your prompt",
-                     container=True,
-                 )
-                 num_audios = gr.Slider(
-                     label="Number of Audios", minimum=1, maximum=10, step=1, value=3
-                 )
+                 steps = gr.Slider(label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500)
+                 prompt = gr.Text(label="Prompt", max_lines=1, placeholder="Enter your prompt")
+                 num_audios = gr.Slider(label="Number of Audios", minimum=1, maximum=10, step=1, value=3)

            submit_button = gr.Button("Submit")

        with gr.Column():
            processed_image = gr.Image(label="Processed Image")
-             generated_audio = gr.Audio(
-                 label="Generated Audio",
-                 show_download_button=True,
-                 show_share_button=True,
-                 waveform_options=gr.WaveformOptions(
-                     waveform_color="#01C6FF",
-                     waveform_progress_color="#0066B4",
-                     show_controls=True,
-                 ),
-             )
-
-     # load examples with manually cached outputs
+             generated_audio = gr.Audio(label="Generated Audio", show_download_button=True)
+
+
+     def on_example_click(example_input):
+         return load_cached_example_outputs(1)  # Always use example 1 for now
+
+
    gr.Examples(
-         examples=[
-             ["examples/1.png", 3, "A scenic mountain view", 500]
-         ],
+         examples=[["examples/1.png", 3, "A scenic mountain view", 500]],  # Example input
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
-         cache_examples="lazy",  # Cache outputs as users interact
-         fn=process_image
+         cache_examples=True,  # Cache examples to avoid running the model
+         fn=on_example_click  # Load the cached output when the example is clicked
    )

+
    submit_button.click(
-         process_image,
-         inputs=[image, num_audios, prompt, steps],
+         fn=on_example_click,
+         inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio]
    )

if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
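
Note: after this commit the Space no longer loads or runs the SEE-2-SOUND model at all; load_cached_example_outputs only reads files that must already exist under gradio_cached_examples/<example_index>/. A minimal sketch of pre-populating that layout for example 1 — the helper name and the source paths are hypothetical, not part of the commit:

import os
import shutil

CACHE_DIR = "gradio_cached_examples"

def populate_example_cache(example_index: int, image_src: str, audio_src: str) -> None:
    # Create gradio_cached_examples/<index>/ and copy in the two files
    # that load_cached_example_outputs looks for.
    cached_dir = os.path.join(CACHE_DIR, str(example_index))
    os.makedirs(cached_dir, exist_ok=True)
    shutil.copy(image_src, os.path.join(cached_dir, "processed_image.png"))
    shutil.copy(audio_src, os.path.join(cached_dir, "audio.wav"))

# Hypothetical sources: outputs generated offline with the full SEE-2-SOUND pipeline.
populate_example_cache(1, "examples/1.png", "offline_outputs/1/audio.wav")

With that directory in place, both the example click and the Submit button return the cached image/audio pair for example 1 without any GPU work.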