# Spaces:
# Runtime error
# Runtime error
from transformers import ClapModel, ClapProcessor | |
import gradio as gr | |
import torch | |
import torchaudio | |
import os | |
import numpy as np | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.models import Distance, VectorParams | |
from qdrant_client.http import models | |
class ClapSSGradio():
    """Gradio text-to-audio semantic search backed by a CLAP model and Qdrant.

    A text query is embedded with the CLAP text encoder, the closest vectors
    are looked up in the Qdrant collection called ``name``, and the matching
    ``.flac`` blobs are loaded from Azure Blob Storage so they can be played
    back in the UI.

    Environment variables read by this class:
        HUGGINGFACE_API_TOKEN: passed as ``use_auth_token`` when loading the
            model and processor.
        AZURE_SAS_TOKEN: required (``KeyError`` if missing); appended as the
            query string of every blob URL.
        QDRANT_URL, QDRANT_API_KEY: connection settings for the Qdrant client.
    """

    # Placeholder clip used to fill unused audio slots: an all-zero array of
    # shape (480000, 2) at 48000 Hz, i.e. ten seconds of silence.
    _PAD_SAMPLE_RATE = 48000
    _PAD_NUM_SAMPLES = 480000
    _PAD_NUM_CHANNELS = 2

    def __init__(
        self,
        name,
        k=10,
    ):
        """Load the model and processor and connect to Qdrant.

        Args:
            name: Model name under the ``Audiogen/`` namespace; also used as
                the Qdrant collection name.
            k: Number of results requested per query and number of audio
                players rendered in the UI.

        Raises:
            KeyError: If ``AZURE_SAS_TOKEN`` is not set.
        """
        self.name = name
        self.k = k

        # Read the token once. os.getenv returns None when the variable is
        # unset, so it must not be sliced unconditionally.
        hf_token = os.getenv('HUGGINGFACE_API_TOKEN')
        print("Env?!")
        print(hf_token[:2] if hf_token is not None else None)

        repo_id = f"Audiogen/{name}"
        self.model = ClapModel.from_pretrained(
            repo_id, use_auth_token=hf_token)
        self.tokenizer = ClapProcessor.from_pretrained(
            repo_id, use_auth_token=hf_token)

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'
        self._start_qdrant()

    def _start_qdrant(self):
        """Create the Qdrant client from ``QDRANT_URL`` / ``QDRANT_API_KEY``."""
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv('QDRANT_API_KEY'),
        )

    def _embed_query(self, query):
        """Return the text embedding of ``query`` as a flat list of floats.

        Args:
            query: The search text.

        Returns:
            The first row of the model's text features, as a Python list.
        """
        inputs = self.tokenizer(
            query,
            return_tensors="pt",
            padding='max_length',
            max_length=77,
            truncation=True,
        )
        # Inference only: without no_grad the output is attached to the
        # autograd graph and Tensor.numpy() raises a RuntimeError.
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features.cpu().numpy().tolist()[0]

    def _similarity_search(self, query):
        """Search the collection for ``query`` and return ``k`` audio clips.

        Args:
            query: The search text typed by the user.

        Returns:
            A list of exactly ``self.k`` ``(sample_rate, samples)`` tuples,
            one per audio output component. Slots without a matching result
            are filled with a silent placeholder clip.
        """
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query),
            limit=self.k,
            score_threshold=0.5,
        )

        containers = []
        filenames = []
        # print to stdout
        print(f"\nQuery: {query}\n")
        for i, result in enumerate(results):
            container = result.payload['container']
            caption = result.payload['caption']
            containers.append(container)
            # The point id is used as the blob file name (without extension).
            filenames.append(result.id)
            print(f"{i}: {container} - {caption}. Score: {result.score}")

        waveforms = self._download_results(containers, filenames)
        if not waveforms:
            print("\nNo results found")

        # The UI always has k audio outputs, so every slot needs a value.
        missing = self.k - len(waveforms)
        if missing > 0:
            waveforms.extend(
                (
                    int(self._PAD_SAMPLE_RATE),
                    np.zeros((self._PAD_NUM_SAMPLES, self._PAD_NUM_CHANNELS)),
                )
                for _ in range(missing)
            )
        return waveforms

    def _download_results(self, containers: list, filenames: list):
        """Load one ``.flac`` file per (container, file name) pair.

        Args:
            containers: Blob container name for each result.
            filenames: Blob name (without the ``.flac`` suffix) for each
                result, aligned with ``containers``.

        Returns:
            A list of ``(sample_rate, samples)`` tuples in input order.
        """
        waveforms = []
        for container, file_name in zip(containers, filenames):
            # NOTE: the URL embeds the SAS token; avoid logging it.
            url = (
                f"https://{self.storage_name}.blob.core.windows.net/"
                f"{container}/{file_name}.flac?{self.sas_token}"
            )
            waveform, sample_rate = torchaudio.load(url)
            # Transposed before handing to the UI; presumably this turns
            # (channels, samples) into (samples, channels) -- confirm against
            # the torchaudio / gradio versions in use.
            waveforms.append((sample_rate, waveform.numpy().T))
        return waveforms

    def launch(self, share=False):
        """Build the Gradio interface and start serving it.

        Args:
            share: Forwarded to ``Blocks.launch``.
        """
        # gradio app structure
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                with gr.Column():
                    gr.Markdown("Output")
                    audioboxes = [
                        gr.components.Audio(label=f"{i}", visible=True)
                        for i in range(self.k)
                    ]
            search.submit(
                fn=self._similarity_search,
                inputs=[search],
                outputs=audioboxes,
            )
        ui.launch(share=share)
if __name__ == "__main__":
    # Script entry point: serve the "clap-2" model/collection without
    # requesting a public share link.
    search_app = ClapSSGradio(name="clap-2")
    search_app.launch(share=False)