leadr64 commited on
Commit
0d3e4a3
·
1 Parent(s): 98c8904

Ajouter le script Gradio et les dépendances

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -1,29 +1,29 @@
1
- import gradio as gr
2
- import laion_clap
3
- from qdrant_client import QdrantClient
4
  import os
5
 
6
- # Utilisez les variables d'environnement pour la configuration
7
- QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
8
- QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))
9
 
10
- # Connexion à Qdrant
11
- client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
12
  print("[INFO] Client created...")
13
 
14
- # Charger le modèle
15
  print("[INFO] Loading the model...")
16
- model_name = "laion/larger_clap_music"
17
- model = laion_clap.CLAP_Module(enable_fusion=False)
18
- model.load_ckpt() # télécharger le checkpoint préentraîné par défaut
19
 
20
- # Interface Gradio
21
  max_results = 10
22
 
 
23
  def sound_search(query):
24
- text_embed = model.get_text_embedding([query, ''])[0] # trick because can't accept singleton
 
 
25
  hits = client.search(
26
- collection_name="demo_db7",
27
  query_vector=text_embed,
28
  limit=max_results,
29
  )
@@ -34,12 +34,13 @@ def sound_search(query):
34
  for hit in hits
35
  ]
36
 
 
37
  with gr.Blocks() as demo:
38
  gr.Markdown(
39
  """# Sound search database """
40
  )
41
  inp = gr.Textbox(placeholder="What sound are you looking for ?")
42
- out = [gr.Audio(label=f"{x}") for x in range(max_results)] # Nécessaire pour avoir différents objets
43
  inp.change(sound_search, inp, out)
44
 
45
  demo.launch()
 
 
 
 
1
  import os
2
 
3
+ import gradio as gr
4
+ from qdrant_client import QdrantClient
5
+ from transformers import ClapModel, ClapProcessor
6
 
7
+ # Loading the Qdrant DB in local ###################################################################
8
+ client = QdrantClient("https://ebe79742-e3ac-4d09-a2c6-63946024cc7a.us-east4-0.gcp.cloud.qdrant.io", api_key="_NnGLuSMH4Qwv-ancoFh88YvzuR7WbyidAorVOVQ_eMCbPhxTb2TSw")
9
  print("[INFO] Client created...")
10
 
11
+ # loading the model
12
  print("[INFO] Loading the model...")
13
+ model_name = "laion/larger_clap_general"
14
+ model = ClapModel.from_pretrained(model_name)
15
+ processor = ClapProcessor.from_pretrained(model_name)
16
 
17
+ # Gradio Interface #################################################################################
18
  max_results = 10
19
 
20
+
21
  def sound_search(query):
22
+ text_inputs = processor(text=query, return_tensors="pt")
23
+ text_embed = model.get_text_features(**text_inputs)[0]
24
+
25
  hits = client.search(
26
+ collection_name="demo_spaces_db",
27
  query_vector=text_embed,
28
  limit=max_results,
29
  )
 
34
  for hit in hits
35
  ]
36
 
37
+
38
  with gr.Blocks() as demo:
39
  gr.Markdown(
40
  """# Sound search database """
41
  )
42
  inp = gr.Textbox(placeholder="What sound are you looking for ?")
43
+ out = [gr.Audio(label=f"{x}") for x in range(max_results)] # Necessary to have different objs
44
  inp.change(sound_search, inp, out)
45
 
46
  demo.launch()