from typing import Tuple, Union
import gradio as gr
import numpy as np
import see2sound
import spaces
import torch
import yaml
import os
from huggingface_hub import snapshot_download
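# Download the pretrained SEE-2-SOUND checkpoints from the Hugging Face Hub.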
model_id = "rishitdagli/see-2-sound"
base_path = snapshot_download(repo_id=model_id)
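# The bundled config.yaml references weights under a relative "checkpoints" folder;
# rewrite those paths so they point at the downloaded snapshot directory.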
with open("config.yaml", "r") as file:
data = yaml.safe_load(file)
data_str = yaml.dump(data)
updated_data_str = data_str.replace("checkpoints", base_path)
updated_data = yaml.safe_load(updated_data_str)
with open("config.yaml", "w") as file:
yaml.safe_dump(updated_data, file)
model = see2sound.See2Sound(config_path="config.yaml")
model.setup()
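# Local example cache: this demo expects each example's outputs under
# gradio_cached_examples/<index>/ (processed_image.png and audio.wav).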
CACHE_DIR = "gradio_cached_examples"
example_mapping = {
    1: os.path.join(CACHE_DIR, "1"),
    2: os.path.join(CACHE_DIR, "2"),
    3: os.path.join(CACHE_DIR, "3"),
}
# Load the cached outputs (processed image and audio) for a local example.
def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
    cached_dir = os.path.join(CACHE_DIR, str(example_index))  # Use the example index to find the directory
    cached_image_path = os.path.join(cached_dir, "processed_image.png")
    cached_audio_path = os.path.join(cached_dir, "audio.wav")

    # Ensure the cached files exist before returning them
    if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
        return cached_image_path, cached_audio_path
    else:
        raise FileNotFoundError(f"Cached outputs not found for example {example_index}")
# Handle an example click by returning the cached outputs for that example's index.
def on_example_click(index: int, *args, **kwargs):
    return load_cached_example_outputs(index)
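# Allocate a ZeroGPU slot for up to 280 seconds and run inference without tracking gradients.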
@spaces.GPU(duration=280)
@torch.no_grad()
def process_image(
    image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
) -> Tuple[str, str]:
    model.run(
        path=image,
        output_path="audio.wav",
        num_audios=num_audios,
        prompt=prompt,
        steps=steps,
    )
    return image, "audio.wav"
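# For illustration (assuming the bundled example assets), a direct call would look like:
#   processed_image, audio_path = process_image("examples/1.png", 3, "A scenic mountain view", 500)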
description_text = """# SEE-2-SOUND 🔊 Demo
Official demo for *SEE-2-SOUND 🔊: Zero-Shot Spatial Environment-to-Spatial Sound*.
Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [github](https://github.com/see2sound/see2sound) for more details.
> Note: You should make sure that your hardware supports spatial audio.
This demo allows you to generate spatial audio given an image. Upload an image (with an optional text prompt in the advanced settings) to geenrate spatial audio to accompany the image.
"""
css = """
h1 {
text-align: center;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(description_text)

    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label="Select an image", sources=["upload", "webcam"], type="filepath"
            )

            with gr.Accordion("Advanced Settings", open=False):
                steps = gr.Slider(
                    label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500
                )
                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=True,
                )
                num_audios = gr.Slider(
                    label="Number of Audios", minimum=1, maximum=10, step=1, value=3
                )

            submit_button = gr.Button("Submit")
        with gr.Column():
            processed_image = gr.Image(label="Processed Image")
            processed_video = gr.Video(label="Processed Video", visible=False)  # Initially hidden
            generated_audio = gr.Audio(
                label="Generated Audio",
                show_download_button=True,
                show_share_button=True,
                waveform_options=gr.WaveformOptions(
                    waveform_color="#01C6FF",
                    waveform_progress_color="#0066B4",
                    show_controls=True,
                ),
            )
    # Example inputs: image path, number of audios, prompt, and diffusion steps
    example = [
        ["examples/1.png", 3, "A scenic mountain view", 500],
        ["examples/2.png", 2, "A forest with birds", 500],
        ["examples/3.png", 1, "A crowded city", 500],
    ]

    def update_examples(index):
        example_index = int(index)  # Convert the incoming index to an integer before lookup
        return load_cached_example_outputs(example_index)
    gr.Examples(
        examples=example,  # Example inputs
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
        cache_examples=True,  # Cache examples to avoid running the model
        fn=lambda *args: on_example_click(int(args[0].split("/")[-1][0])),  # Extract the example index from the image path
    )
    gr.on(
        triggers=[submit_button.click],
        fn=process_image,
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
    )
if __name__ == "__main__":
    demo.launch(debug=True)