from transformers import ClapModel, ClapProcessor, AutoFeatureExtractor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http import models
import dotenv
dotenv.load_dotenv()
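
# Configuration is read from the environment (a local .env file is picked up
# by python-dotenv above). The script expects:
#   HUGGINGFACE_API_TOKEN - token used to load the private Audiogen CLAP checkpoints
#   AZURE_SAS_TOKEN       - SAS token for the Azure Blob Storage account holding the audio
#   QDRANT_URL            - URL of the Qdrant instance with the embeddings collection
#   QDRANT_API_KEY        - API key for that Qdrant instance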

class ClapSSGradio:
    """Gradio front end for CLAP-based semantic search over a Qdrant audio collection."""

    def __init__(
        self,
        name,
        model="clap-2",
        k=10,
    ):
        self.name = name
        self.k = k

        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()
    def _start_qdrant(self):
        self.client = QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv('QDRANT_API_KEY'))
        # print(self.client.get_collection(collection_name=self.name))
    @torch.no_grad()
    def _embed_query(self, query, audio_file):
        if audio_file is not None:
            waveform, sample_rate = torchaudio.load(audio_file.name)
            print("Waveform shape:", waveform.shape)
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, 48000)
            print("Resampled waveform shape:", waveform.shape)
            # CLAP expects a 10-second window at 48 kHz (480,000 samples):
            # pad shorter clips with silence, truncate longer ones.
            if waveform.shape[-1] < 480000:
                waveform = torch.nn.functional.pad(
                    waveform, (0, 480000 - waveform.shape[-1]))
            elif waveform.shape[-1] > 480000:
                waveform = waveform[..., :480000]
            audio_prompt_features = self.processor(
                audios=waveform.mean(0), return_tensors='pt', sampling_rate=48000
            )['input_features']
            print("Audio prompt features shape:", audio_prompt_features.shape)
            e = self.model.get_audio_features(
                input_features=audio_prompt_features)[0]
            if torch.isnan(e).any():
                raise ValueError("Audio features are NaN")
            print("Embeddings: ", e.shape)
            # return a plain list so both branches hand Qdrant the same vector type
            return e.cpu().numpy().tolist()
        else:
            inputs = self.processor(
                query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
            return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
    def _similarity_search(self, query, threshold, audio_file):
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]

        # log the hits to stdout
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. Score: {score}")

        waveforms = self._download_results(containers, filenames)

        if len(waveforms) == 0:
            print("\nNo results found")

        # pad with silent placeholders so every output Audio component receives a value
        if len(waveforms) < self.k:
            waveforms.extend([(48000, np.zeros((480000, 2)))
                              for _ in range(self.k - len(waveforms))])

        return waveforms
    def _download_results(self, containers: list, filenames: list):
        # Build SAS-authenticated blob URLs. The container segment is hard-coded
        # to 'snake' for this demo; the per-result container values are currently
        # only used for logging.
        urls = [
            f"https://{self.storage_name}.blob.core.windows.net/snake/{file_name}.flac?{self.sas_token}"
            for file_name in filenames
        ]

        # fetch and decode each result
        waveforms = []
        for url in urls:
            waveform, sample_rate = torchaudio.load(url)
            waveforms.append((sample_rate, waveform.numpy().T))
        return waveforms
    def launch(self, share=False):
        # gradio app structure
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                    float_input = gr.Number(
                        label='Similarity threshold [min: 0.1 max: 1]', value=0.5, minimum=0.1, maximum=1)
                    audio_file = gr.File(
                        label='Upload an Audio File', type="file")
                    search_button = gr.Button("Search", label='Search')

                with gr.Column():
                    audioboxes = []
                    gr.Markdown("Output")
                    for i in range(self.k):
                        t = gr.components.Audio(label=f"{i}", visible=True)
                        audioboxes.append(t)

            search_button.click(fn=self._similarity_search, inputs=[
                                search, float_input, audio_file], outputs=audioboxes)

        ui.launch(share=share)
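
# Running this file directly builds the app against the "demo" Qdrant collection
# and serves the UI locally; passing share=True to launch() would instead expose
# a temporary public Gradio link.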
if __name__ == "__main__":
    app = ClapSSGradio("demo")
    app.launch(share=False)