# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
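# Gradio demo for Audio Flamingo: the model generates several candidate answers
# for an (audio, instruction) pair, and a LAION-CLAP model reranks them by
# audio-text similarity before the best one is shown.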
import os
import yaml
# import spaces
import gradio as gr
import librosa
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import torch
import laion_clap
from inference_utils import prepare_tokenizer, prepare_model, inference
from data import AudioTextDataProcessor
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
# @spaces.GPU
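# Load the LAION-CLAP reranker (HTSAT-tiny audio encoder, fusion variant).
# Assumes the '630k-audioset-fusion-best.pt' checkpoint is present next to this script.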
def load_laionclap():
    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
    model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    model.eval()
    return model
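# Round-tripping float32 -> int16 -> float32 matches the quantization step used
# in the LAION-CLAP usage examples.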
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)

def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
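# Decode up to `duration` seconds of audio starting at `start`, resample to
# `target_sr`, downmix to mono, and normalize to [-1, 1]. MP3 files are decoded
# with pydub; everything else goes through soundfile.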
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    if file_path.endswith('.mp3'):
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            max_frames = int((start + duration) * original_sr)
            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    return data
# @spaces.GPU
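# Score each candidate answer by the cosine similarity between its CLAP text
# embedding and the CLAP audio embedding of the input file (loaded at the
# 48 kHz rate LAION-CLAP expects).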
@torch.no_grad()
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        print(audio_file, 'unsuccessful due to', e)
        return [0.0] * len(outputs)
    audio_data = data.reshape(1, -1)
    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().to(device)
    audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
    text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    return cos_similarity.squeeze().cpu().numpy()
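# Sampling configuration for generation: 20 candidates are sampled per query,
# and the best one is selected later by CLAP similarity.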
inference_kwargs = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 20
}
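# Build everything the demo needs: configs, tokenizer, data processor,
# LAION-CLAP reranker, and the Audio Flamingo checkpoint.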
with open('chat.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']
text_tokenizer = prepare_tokenizer(model_config)
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)
laionclap_model = load_laionclap()
model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt',
    device=device
)
# @spaces.GPU
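# Full pipeline for one query: preprocess the (audio, prompt) pair, generate
# candidate answers, rerank them with LAION-CLAP, and return the top answer.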
def inference_item(name, prompt):
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)
    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
        device=device
    )
    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs
    )
    outputs_joint = [(output, score) for (output, score) in zip(outputs, laionclap_scores)]
    outputs_joint.sort(key=lambda x: -x[1])
    return outputs_joint[0][0]
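# Example call (assuming the bundled sample files are present), e.g.:
#   inference_item('audio/wav1.wav', 'Can you briefly describe what you hear in this audio?')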
css = """
a {
    color: inherit;
    text-decoration: underline;
}
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: #000000;
    background: #000000;
}
input[type='range'] {
    accent-color: #000000;
}
.dark input[type='range'] {
    accent-color: #dfdfdf;
}
.container {
    max-width: 730px;
    margin: auto;
    padding-top: 1.5rem;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
    margin-left: auto;
    margin-right: auto;
    border-bottom-right-radius: .5rem !important;
    border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
    min-height: 20rem;
}
.details:hover {
    text-decoration: underline;
}
.gr-button {
    white-space: nowrap;
}
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
#advanced-btn {
    font-size: .7rem !important;
    line-height: 19px;
    margin-top: 12px;
    margin-bottom: 12px;
    padding: 2px 8px;
    border-radius: 14px !important;
}
#advanced-options {
    margin-bottom: 20px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4 {
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
#container-advanced-btns {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    align-items: center;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
    background-color: #000000;
    justify-content: center;
    align-items: center;
    border-radius: 9999px !important;
    width: 13rem;
    margin-top: 10px;
    margin-left: auto;
}
#share-btn {
    all: initial;
    color: #ffffff;
    font-weight: 600;
    cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif;
    margin-left: 0.5rem !important;
    padding-top: 0.25rem !important;
    padding-bottom: 0.25rem !important;
    right: 0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
.gr-form {
    flex: 1 1 50%;
    border-top-right-radius: 0;
    border-bottom-right-radius: 0;
}
#prompt-container {
    gap: 0;
}
#generated_id {
    min-height: 700px;
}
#setting_id {
    margin-bottom: 12px;
    text-align: center;
    font-weight: 900;
}
"""
ui = gr.Blocks(css=css, title="Audio Flamingo - Demo")
with ui:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
            <div
                style="
                    display: inline-flex;
                    align-items: center;
                    gap: 0.8rem;
                    font-size: 1.5rem;
                "
            >
                <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
                    Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 125%">
                <a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo Website]</a> <a href="https://www.youtube.com/watch?v=ucttuS28RVE">[Demo Video]</a>
            </p>
        </div>
        """
    )
    gr.HTML(
        """
        <div>
            <h3>Overview</h3>
            Audio Flamingo is an audio language model that can understand sounds beyond speech.
            It can also answer questions about the audio in natural language. <br>
            Examples of questions include: <br>
            - Can you briefly describe what you hear in this audio? <br>
            - What is the emotion conveyed in this music? <br>
            - Where is this audio usually heard? <br>
            - Where is this music usually played? <br>
        </div>
        """
    )
    name = gr.Textbox(
        label="Audio file path (choose one from: audio/wav{1--6}.wav)",
        value="audio/wav1.wav"
    )
    prompt = gr.Textbox(
        label="Instruction",
        value='Can you briefly describe what you hear in this audio?'
    )
    with gr.Row():
        play_audio_button = gr.Button("Play Audio")
        audio_output = gr.Audio(label="Playback")
    play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)
    inference_button = gr.Button("Inference")
    output_text = gr.Textbox(label="Audio Flamingo output")
    inference_button.click(
        fn=inference_item,
        inputs=[name, prompt],
        outputs=output_text
    )
ui.queue()
ui.launch()