Spaces:
Build error
Build error
# Copyright (c) 2024 NVIDIA CORPORATION. | |
# Licensed under the MIT license. | |
import os | |
import yaml | |
# import spaces | |
import gradio as gr | |
import librosa | |
from pydub import AudioSegment | |
import soundfile as sf | |
import numpy as np | |
import torch | |
import laion_clap | |
from inference_utils import prepare_tokenizer, prepare_model, inference | |
from data import AudioTextDataProcessor | |
# Run on the first CUDA device when one is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# @spaces.GPU
def load_laionclap():
    """Build the LAION-CLAP reranking model and return it in eval mode.

    Creates a fusion-enabled ``CLAP_Module`` with the ``HTSAT-tiny`` audio
    encoder, moves it to the module-level ``device``, and loads weights from
    the checkpoint file ``630k-audioset-fusion-best.pt``.
    """
    clap = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny')
    clap = clap.to(device)
    clap.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    clap.eval()
    return clap
def int16_to_float32(x):
    """Scale int16-range PCM samples to float32 by dividing by 32767."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
def float32_to_int16(x):
    """Clip float samples to [-1, 1] and scale them to int16 PCM.

    The float-to-int cast truncates toward zero.
    """
    clipped = np.clip(x, -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    """Load up to ``duration`` seconds of mono audio beginning at ``start``.

    ``.mp3`` files are decoded with pydub and converted to mono with
    ``set_channels(1)``. Every other file is read with soundfile, and only the
    first channel of multi-channel input is kept. The clip is resampled to
    ``target_sr`` and then peak-normalized (see ``_peak_normalize``).

    Args:
        file_path: Path to the audio file.
        target_sr: Output sample rate in Hz.
        duration: Maximum clip length in seconds.
        start: Offset in seconds from the beginning of the file.

    Returns:
        1-D float numpy array of samples.

    Raises:
        ValueError: If the mp3 sample width is unsupported, or if no samples
            could be read (e.g. ``start`` is at or past the end of the file).
    """
    if file_path.endswith('.mp3'):
        data = _load_mp3_clip(file_path, target_sr, duration, start)
    else:
        data = _load_soundfile_clip(file_path, target_sr, duration, start)
    return _peak_normalize(data, file_path)


def _load_mp3_clip(file_path, target_sr, duration, start):
    """Decode an mp3 window with pydub; return mono float32 samples at ``target_sr``."""
    audio = AudioSegment.from_file(file_path)
    # pydub slices in milliseconds and clamps the end to the clip length, so
    # slicing unconditionally is safe and honours ``start`` for short files too.
    audio = audio[start * 1000:(start + duration) * 1000]
    if audio.frame_rate != target_sr:
        audio = audio.set_frame_rate(target_sr)
    if audio.channels > 1:
        audio = audio.set_channels(1)
    samples = np.array(audio.get_array_of_samples())
    if audio.sample_width == 2:
        return samples.astype(np.float32) / np.iinfo(np.int16).max
    if audio.sample_width == 4:
        return samples.astype(np.float32) / np.iinfo(np.int32).max
    raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))


def _load_soundfile_clip(file_path, target_sr, duration, start):
    """Read a window with soundfile; return channel 0 resampled to ``target_sr``."""
    with sf.SoundFile(file_path) as audio:
        original_sr = audio.samplerate
        start_frame = int(start * original_sr)
        audio.seek(start_frame)
        # Read ``duration`` seconds from the seek position, but never more
        # than the frames that remain in the file.
        frames_to_read = min(int(duration * original_sr), len(audio) - start_frame)
        data = audio.read(frames_to_read)
    if data.size == 0:
        raise ValueError("No audio samples read from {} (start={}s)".format(file_path, start))
    # soundfile returns (frames, channels) for multi-channel files; keep channel 0.
    if data.ndim > 1:
        data = data[:, 0]
    # No pre-scaling is needed here. Resampling is linear and _peak_normalize is
    # scale-invariant, so any overall gain is removed at the end anyway.
    if original_sr != target_sr:
        data = librosa.resample(data, orig_sr=original_sr, target_sr=target_sr)
    return data


def _peak_normalize(data, file_path=''):
    """Rescale samples so the peak magnitude is 1.

    Non-negative signals are mapped from [0, max] onto [-1, 1]. An all-zero
    (silent) signal is returned unchanged instead of becoming NaN through a
    0/0 division.
    """
    if data.size == 0:
        raise ValueError("No audio samples read from {}".format(file_path))
    if data.min() >= 0:
        peak = abs(data.max())
        if peak == 0:
            return data
        return 2 * data / peak - 1.0
    # Here min < 0, so the peak is strictly positive.
    return data / max(abs(data.max()), abs(data.min()))
# @spaces.GPU
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    """Score each candidate text against an audio clip with LAION-CLAP.

    Args:
        audio_file: Path to the audio clip; loaded via ``load_audio`` at 48 kHz.
        laionclap_model: The CLAP model returned by ``load_laionclap``.
        outputs: Sequence of candidate text strings.

    Returns:
        1-D float32 numpy array with one cosine similarity per entry of
        ``outputs``. The array stays 1-D even for a single candidate. If the
        audio cannot be loaded, an all-zero array of the same length is
        returned.
    """
    if len(outputs) == 0:
        return np.zeros(0, dtype=np.float32)
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        # Best effort: a bad file should not crash the demo, so report the
        # failure and fall back to neutral (all-zero) scores.
        print(audio_file, 'unsuccessful due to', e)
        return np.zeros(len(outputs), dtype=np.float32)
    audio_data = data.reshape(1, -1)
    # Round-trip through int16 to quantize the waveform. This mirrors the
    # quantization step in laion_clap's usage example.
    audio_data_tensor = torch.from_numpy(
        int16_to_float32(float32_to_int16(audio_data))
    ).float().to(device)
    # Inference only. The CLAP parameters may require grad, and .numpy()
    # refuses tensors that are part of an autograd graph.
    with torch.no_grad():
        audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
        text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    # reshape(-1) rather than squeeze(): squeeze() would give a 0-d array for
    # a single candidate, and callers cannot zip over that.
    return cos_similarity.reshape(-1).cpu().numpy()
# Sampling settings forwarded to inference(): 20 sampled candidates
# (top-k 50, nucleus p=0.95) that inference_item() then reranks with CLAP.
inference_kwargs = dict(
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_return_sequences=20,
)
# Load the demo configuration. The `with` block closes the file once it has
# been parsed. FullLoader is only acceptable because chat.yaml is a trusted
# local file; never use it on user-supplied YAML.
with open('chat.yaml', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']

text_tokenizer = prepare_tokenizer(model_config)
# Used by inference_item() to turn an item dict into model inputs.
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

# CLAP model used to rerank sampled answers, and the main chat model
# loaded from chat.pt.
laionclap_model = load_laionclap()
model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device
)
# @spaces.GPU
def inference_item(name, prompt):
    """Answer ``prompt`` about the audio at ``name`` and return the best candidate.

    The model samples several candidate answers (see ``inference_kwargs``).
    Each candidate is scored against the audio with LAION-CLAP, and the
    highest-scoring one is returned.

    Args:
        name: Path to the audio file, taken verbatim from the UI.
        prompt: The user's instruction or question.

    Returns:
        The highest-scoring candidate text, or '' if the model produced none.
    """
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)
    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
        device=device
    )
    if len(outputs) == 0:
        return ''
    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs
    )
    # ravel() makes a 0-d score array (a squeezed single score) iterable for zip().
    scores = np.ravel(laionclap_scores)
    # max() returns the first maximal pair. That is the same element a stable
    # sort by descending score would place first.
    best_output, _ = max(zip(outputs, scores), key=lambda pair: pair[1])
    return best_output
# Custom stylesheet passed to gr.Blocks for the demo page.
# CSS calc() requires a whitespace-surrounded `+` between operands. Without it,
# --tw-ring-shadow is invalid and the button focus ring is dropped.
css = """
a {
    color: inherit;
    text-decoration: underline;
}
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: #000000;
    background: #000000;
}
input[type='range'] {
    accent-color: #000000;
}
.dark input[type='range'] {
    accent-color: #dfdfdf;
}
.container {
    max-width: 730px;
    margin: auto;
    padding-top: 1.5rem;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
    margin-left: auto;
    margin-right: auto;
    border-bottom-right-radius: .5rem !important;
    border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
    min-height: 20rem;
}
.details:hover {
    text-decoration: underline;
}
.gr-button {
    white-space: nowrap;
}
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
#advanced-btn {
    font-size: .7rem !important;
    line-height: 19px;
    margin-top: 12px;
    margin-bottom: 12px;
    padding: 2px 8px;
    border-radius: 14px !important;
}
#advanced-options {
    margin-bottom: 20px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4{
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
#container-advanced-btns{
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    align-items: center;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
    margin-top: 10px;
    margin-left: auto;
}
#share-btn {
    all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2){
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
.gr-form{
    flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
}
#prompt-container{
    gap: 0;
}
#generated_id{
    min-height: 700px;
}
#setting_id{
    margin-bottom: 12px;
    text-align: center;
    font-weight: 900;
}
"""
# Gradio interface. Components are created inside `with ui:` so they attach
# to this Blocks instance.
ui = gr.Blocks(css=css, title="Audio Flamingo - Demo")
with ui:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
          <div
            style="
              display: inline-flex;
              align-items: center;
              gap: 0.8rem;
              font-size: 1.5rem;
            "
          >
            <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
              Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
            </h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 125%">
            <a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo Website]</a> <a href="https://www.youtube.com/watch?v=ucttuS28RVE">[Demo Video]</a>
          </p>
        </div>
        """
    )
    gr.HTML(
        """
        <div>
          <h3>Overview</h3>
          Audio Flamingo is an audio language model that can understand sounds beyond speech.
          It can also answer questions about the sound in natural language. <br>
          Examples of questions include: <br>
          - Can you briefly describe what you hear in this audio? <br>
          - What is the emotion conveyed in this music? <br>
          - Where is this audio usually heard? <br>
          - What place is this music usually played at? <br>
        </div>
        """
    )
    # NOTE(review): this path is passed verbatim to gr.Audio and to
    # inference_item(), which reads the file on the server. Consider
    # restricting it to the bundled audio/ directory. Whether Gradio itself
    # blocks other paths depends on the Gradio version's file-access rules.
    name = gr.Textbox(
        label="Audio file path (choose one from: audio/wav{1--6}.wav)",
        value="audio/wav1.wav"
    )
    prompt = gr.Textbox(
        label="Instruction",
        value='Can you briefly describe what you hear in this audio?'
    )
    with gr.Row():
        play_audio_button = gr.Button("Play Audio")
        audio_output = gr.Audio(label="Playback")
        # Echo the typed path into the Audio component so it can be played back.
        play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)
    inference_button = gr.Button("Inference")
    output_text = gr.Textbox(label="Audio Flamingo output")
    inference_button.click(
        fn=inference_item,
        inputs=[name, prompt],
        outputs=output_text
    )

ui.queue()
# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    ui.launch()