theOnlyJaco committed on
Commit
6667d8a
·
unverified ·
1 Parent(s): fd3a2ba
Files changed (2) hide show
  1. app.py +117 -0
  2. requirements.txt +89 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ClapModel, ClapProcessor
2
+ import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ import os
6
+ import numpy as np
7
+ from qdrant_client import QdrantClient
8
+ from qdrant_client.http.models import Distance, VectorParams
9
+ from qdrant_client.http import models
10
+
11
+
12
+
13
+
14
+
15
class ClapSSGradio():
    """Gradio front end for CLAP-based semantic search over audio samples.

    A text query is embedded with the CLAP text encoder, the nearest entries
    are looked up in the Qdrant collection named after the model, and the
    matching ``.flac`` blobs are loaded from Azure Blob Storage so they can
    be played back in the browser.

    Environment variables read:
        HUGGINGFACE_API_TOKEN: passed to ``from_pretrained`` (may be unset).
        AZURE_SAS_TOKEN: query string appended to every blob URL (required).
        QDRANT_URL / QDRANT_API_KEY: passed to ``QdrantClient``.
    """

    # Minimum similarity score a Qdrant hit needs in order to be returned.
    SCORE_THRESHOLD = 0.5
    # Geometry of the silent clip used to fill unused audio slots:
    # 480000 samples at 48000 Hz (10 s), 2 channels, laid out (samples, channels).
    SILENCE_SAMPLE_RATE = 48000
    SILENCE_NUM_SAMPLES = 480000
    SILENCE_NUM_CHANNELS = 2

    def __init__(
        self,
        name,
        k=10,
    ):
        """Load the model and processor and connect to Qdrant.

        Args:
            name: Model name under the ``Audiogen/`` namespace; also used as
                the Qdrant collection name.
            k: Maximum number of results per query, and the number of audio
                players rendered in the UI.

        Raises:
            KeyError: if ``AZURE_SAS_TOKEN`` is not set.
        """
        self.name = name
        self.k = k

        # Read the token once. Only report whether it is present: slicing the
        # value crashed with TypeError when the variable was unset and wrote
        # part of a secret to stdout when it was set.
        hf_token = os.getenv('HUGGINGFACE_API_TOKEN')
        print("Env?!")
        print(f"HUGGINGFACE_API_TOKEN set: {hf_token is not None}")

        self.model = ClapModel.from_pretrained(
            f"Audiogen/{name}", use_auth_token=hf_token)
        # NOTE: despite the attribute name this holds a ClapProcessor.
        self.tokenizer = ClapProcessor.from_pretrained(
            f"Audiogen/{name}", use_auth_token=hf_token)

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()

    def _start_qdrant(self):
        """Create the Qdrant client from ``QDRANT_URL`` and ``QDRANT_API_KEY``."""
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv('QDRANT_API_KEY'),
        )

    @torch.no_grad()
    def _embed_query(self, query):
        """Return the text embedding of ``query`` as a flat list of floats.

        The query is padded/truncated to 77 tokens before being encoded.
        """
        inputs = self.tokenizer(
            query,
            return_tensors="pt",
            padding='max_length',
            max_length=77,
            truncation=True,
        )
        features = self.model.get_text_features(**inputs)
        # Single query in, so take the first (only) row of the batch.
        return features.cpu().numpy().tolist()[0]

    def _similarity_search(self, query):
        """Search Qdrant for ``query`` and return exactly ``self.k`` audio clips.

        Returns:
            A list of ``(sample_rate, samples)`` tuples, one per audio player.
            Slots without a matching result are filled with silence so the
            number of outputs always matches the number of players.
        """
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query),
            limit=self.k,
            score_threshold=self.SCORE_THRESHOLD,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]

        # print to stdout
        print(f"\nQuery: {query}\n")
        for i, result in enumerate(results):
            print(f"{i}: {result.payload['container']} - "
                  f"{result.payload['caption']}. Score: {result.score}")

        waveforms = self._download_results(containers, filenames)

        if not waveforms:
            print("\nNo results found")

        return self._pad_results(waveforms)

    def _pad_results(self, waveforms: list):
        """Extend ``waveforms`` in place with silent clips up to ``self.k`` entries."""
        missing = self.k - len(waveforms)
        if missing > 0:
            shape = (self.SILENCE_NUM_SAMPLES, self.SILENCE_NUM_CHANNELS)
            waveforms.extend(
                (self.SILENCE_SAMPLE_RATE, np.zeros(shape))
                for _ in range(missing)
            )
        return waveforms

    def _blob_url(self, container, file_name):
        """Build the SAS-authenticated URL of one ``.flac`` blob."""
        return (f"https://{self.storage_name}.blob.core.windows.net/"
                f"{container}/{file_name}.flac?{self.sas_token}")

    def _download_results(self, containers: list, filenames: list):
        """Load each ``(container, filename)`` pair as audio.

        Returns:
            A list of ``(sample_rate, samples)`` tuples in input order, where
            ``samples`` is the loaded waveform transposed to the layout the
            Gradio audio component is given here.
        """
        waveforms = []
        for container, file_name in zip(containers, filenames):
            waveform, sample_rate = torchaudio.load(
                self._blob_url(container, file_name))
            waveforms.append((sample_rate, waveform.numpy().T))

        return waveforms

    def launch(self, share=False):
        """Build the search UI and start the Gradio server.

        Args:
            share: Forwarded to ``Blocks.launch``.
        """
        # gradio app structure
        with gr.Blocks(title='Clap Semantic Search') as ui:

            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')

                with gr.Column():
                    audioboxes = []
                    gr.Markdown("Output")
                    # One player per possible result; _similarity_search
                    # always returns exactly self.k clips to fill them.
                    for i in range(self.k):
                        audioboxes.append(
                            gr.components.Audio(label=f"{i}", visible=True))

            search.submit(fn=self._similarity_search,
                          inputs=[search], outputs=audioboxes)

        ui.launch(share=share)
113
+
114
+
115
if __name__ == "__main__":
    # Script entry point: serve the "clap-2" model locally, without a public share link.
    ClapSSGradio("clap-2").launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.1.2
3
+ annotated-types==0.5.0
4
+ anyio==3.7.1
5
+ attrs==23.1.0
6
+ certifi==2023.7.22
7
+ charset-normalizer==3.3.0
8
+ click==8.1.7
9
+ contourpy==1.1.1
10
+ cycler==0.12.0
11
+ exceptiongroup==1.1.3
12
+ fastapi==0.103.2
13
+ ffmpy==0.3.1
14
+ filelock==3.12.4
15
+ fonttools==4.43.0
16
+ fsspec==2023.9.2
17
+ gradio==3.46.1
18
+ gradio_client==0.5.3
19
+ grpcio==1.59.0
20
+ grpcio-tools==1.59.0
21
+ h11==0.14.0
22
+ h2==4.1.0
23
+ hpack==4.0.0
24
+ httpcore==0.18.0
25
+ httpx==0.25.0
26
+ huggingface-hub==0.16.4
27
+ hyperframe==6.0.1
28
+ idna==3.4
29
+ importlib-resources==6.1.0
30
+ Jinja2==3.1.2
31
+ jsonschema==4.19.1
32
+ jsonschema-specifications==2023.7.1
33
+ kiwisolver==1.4.5
34
+ MarkupSafe==2.1.3
35
+ matplotlib==3.8.0
36
+ mpmath==1.3.0
37
+ networkx==3.1
38
+ numpy==1.26.0
39
+ nvidia-cublas-cu12==12.1.3.1
40
+ nvidia-cuda-cupti-cu12==12.1.105
41
+ nvidia-cuda-nvrtc-cu12==12.1.105
42
+ nvidia-cuda-runtime-cu12==12.1.105
43
+ nvidia-cudnn-cu12==8.9.2.26
44
+ nvidia-cufft-cu12==11.0.2.54
45
+ nvidia-curand-cu12==10.3.2.106
46
+ nvidia-cusolver-cu12==11.4.5.107
47
+ nvidia-cusparse-cu12==12.1.0.106
48
+ nvidia-nccl-cu12==2.18.1
49
+ nvidia-nvjitlink-cu12==12.2.140
50
+ nvidia-nvtx-cu12==12.1.105
51
+ orjson==3.9.7
52
+ packaging==23.2
53
+ pandas==2.1.1
54
+ Pillow==10.0.1
55
+ portalocker==2.8.2
56
+ protobuf==4.24.4
57
+ pydantic==2.4.2
58
+ pydantic_core==2.10.1
59
+ pydub==0.25.1
60
+ pyparsing==3.1.1
61
+ python-dateutil==2.8.2
62
+ python-dotenv==1.0.0
63
+ python-multipart==0.0.6
64
+ pytz==2023.3.post1
65
+ PyYAML==6.0.1
66
+ qdrant-client==1.5.4
67
+ referencing==0.30.2
68
+ regex==2023.10.3
69
+ requests==2.31.0
70
+ rpds-py==0.10.3
71
+ safetensors==0.3.3
72
+ semantic-version==2.10.0
73
+ six==1.16.0
74
+ sniffio==1.3.0
75
+ starlette==0.27.0
76
+ sympy==1.12
77
+ tokenizers==0.14.0
78
+ toolz==0.12.0
79
+ torch==2.1.0
80
+ torchaudio==2.1.0
81
+ tqdm==4.66.1
82
+ transformers==4.34.0
83
+ triton==2.1.0
84
+ typing_extensions==4.8.0
85
+ tzdata==2023.3
86
+ urllib3==1.26.17
87
+ uvicorn==0.23.2
88
+ websockets==11.0.3
89
+ zipp==3.17.0