from transformers import ClapModel, ClapProcessor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http import models
import dotenv

dotenv.load_dotenv()


class ClapSSGradio():

    def __init__(
        self,
        name,
        model="clap-2",
        k=10,
    ):
        self.name = name
        self.k = k

        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()

    def _start_qdrant(self):
        self.client = QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv('QDRANT_API_KEY'))
        # print(self.client.get_collection(collection_name=self.name))

    @torch.no_grad()
    def _embed_query(self, query, audio_file):
        if audio_file is not None:
            waveform, sample_rate = torchaudio.load(audio_file.name)
            print("Waveform shape:", waveform.shape)

            # CLAP expects 10 s of audio at 48 kHz (480000 samples):
            # resample, then pad or truncate to that length
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, 48000)
            print("Resampled waveform shape:", waveform.shape)
            if waveform.shape[-1] < 480000:
                waveform = torch.nn.functional.pad(
                    waveform, (0, 480000 - waveform.shape[-1]))
            elif waveform.shape[-1] > 480000:
                waveform = waveform[..., :480000]

            audio_prompt_features = self.processor(
                audios=waveform.mean(0),  # downmix to mono
                return_tensors='pt',
                sampling_rate=48000
            )['input_features']
            print("Audio prompt features shape:", audio_prompt_features.shape)

            e = self.model.get_audio_features(
                input_features=audio_prompt_features)[0]
            if torch.isnan(e).any():
                raise ValueError("Audio features are NaN")
            print("Embeddings: ", e.shape)
            # return a plain list so Qdrant accepts it as a query vector,
            # matching the text branch below
            return e.cpu().numpy().tolist()
        else:
            inputs = self.processor(
                query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
            return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]

    def _similarity_search(self, query, threshold, audio_file):
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]

        # print to stdout
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. \nScore: {score}")
Score: {score}") waveforms = self._download_results(containers, filenames) if len(waveforms) == 0: print("\nNo results found") if len(waveforms) < self.k: waveforms.extend([(int(48000), np.zeros((480000, 2))) for _ in range(self.k - len(waveforms))]) return waveforms def _download_results(self, containers: list, filenames: list): # construct url urls = [f"https://{self.storage_name}.blob.core.windows.net/snake/{file_name}.flac?{self.sas_token}" for file_name in filenames] # make requests waveforms = [] for url in urls: waveform, sample_rate = torchaudio.load(url) waveforms.append(tuple([sample_rate, waveform.numpy().T])) return waveforms def launch(self, share=False): # gradio app structure with gr.Blocks(title='Clap Semantic Search') as ui: with gr.Row(): with gr.Column(variant='panel'): search = gr.Textbox(placeholder='Search Samples') float_input = gr.Number( label='Similarity threshold [min: 0.1 max: 1]', value=0.5, minimum=0.1, maximum=1) audio_file = gr.File( label='Upload an Audio File', type="file") search_button = gr.Button("Search", label='Search') with gr.Column(): audioboxes = [] gr.Markdown("Output") for i in range(self.k): t = gr.components.Audio(label=f"{i}", visible=True) audioboxes.append(t) search_button.click(fn=self._similarity_search, inputs=[ search, float_input, audio_file], outputs=audioboxes) ui.launch(share=share) if __name__ == "__main__": app = ClapSSGradio("demo") app.launch(share=False)