fffiloni committed on
Commit
729a510
•
1 Parent(s): e3fb84a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -40
app.py CHANGED
@@ -1,15 +1,16 @@
 
 
1
  import time
2
  import base64
3
- import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
-
6
  import httpx
7
  import json
8
-
9
  import os
10
  import requests
11
  import urllib
12
-
13
  from os import path
14
  from pydub import AudioSegment
15
 
@@ -21,25 +22,16 @@ img_to_text = gr.Blocks.load(name="spaces/fffiloni/CLIP-Interrogator-2")
21
 
22
  from share_btn import community_icon_html, loading_icon_html, share_js
23
  from utils import get_tags_for_prompts, get_mubert_tags_embeddings
 
24
  minilm = SentenceTransformer('all-MiniLM-L6-v2')
25
  mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
26
 
27
- def get_prompts(uploaded_image, track_duration, gen_intensity, gen_mode):
28
- print("calling clip interrogator")
29
- #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
30
- prompt = img_to_text(uploaded_image, 'best', 4, fn_index=1)[0]
31
- print(prompt)
32
- pat = get_pat_token()
33
- music_result = get_music(pat, prompt, track_duration, gen_intensity, gen_mode)
34
- #music_result = generate_track_by_prompt(pat, prompt, track_duration, gen_intensity, gen_mode)
35
- #print(pat
36
- time.sleep(1)
37
- return music_result, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
38
-
39
-
40
-
41
 
 
 
42
 
 
43
  def get_pat_token():
44
  r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess',
45
  json={
@@ -54,11 +46,10 @@ def get_pat_token():
54
  })
55
 
56
  rdata = json.loads(r.text)
57
- #print(rdata)
58
  assert rdata['status'] == 1, "probably incorrect e-mail"
59
- #pat = rdata['data']['pat']
60
- print(rdata['data']['pat'])
61
- return rdata['data']['pat']
62
 
63
  def get_music(pat, prompt, track_duration, gen_intensity, gen_mode):
64
 
@@ -69,32 +60,59 @@ def get_music(pat, prompt, track_duration, gen_intensity, gen_mode):
69
  {
70
  "text": prompt,
71
  "pat": pat,
72
- "mode":"track",
73
  "duration":track_duration,
 
74
  }
75
  })
76
 
77
  rdata = json.loads(r.text)
78
 
79
- #print(track)
80
  assert rdata['status'] == 1, rdata['error']['text']
81
- #track = rdata['data']['tasks']['download_link']
82
-
83
- print(rdata)
84
- time.sleep(2)
85
- track=rdata['data']['tasks'][0]['download_link']
86
- return track
87
-
88
- #print('Generating track ', end='')
89
- #for i in range(20):
90
-
91
- # r = httpx.get(track)
92
- # if r.status_code == 200:
93
- # return track
94
- # time.sleep(1)
95
-
96
-
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def get_track_by_tags(tags, pat, duration, gen_intensity, gen_mode, maxit=20):
99
 
100
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
 
1
+ import gradio as gr
2
+ import numpy as np
3
  import time
4
  import base64
5
+ import ffmpeg
6
  from sentence_transformers import SentenceTransformer
7
+ from audio2numpy import open_audio
8
  import httpx
9
  import json
 
10
  import os
11
  import requests
12
  import urllib
13
+ import pydub
14
  from os import path
15
  from pydub import AudioSegment
16
 
 
22
 
23
  from share_btn import community_icon_html, loading_icon_html, share_js
24
  from utils import get_tags_for_prompts, get_mubert_tags_embeddings
25
+
26
  minilm = SentenceTransformer('all-MiniLM-L6-v2')
27
  mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
28
 
29
+ ##————————————————————————————————————
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ MUBERT_LICENSE = os.environ.get('MUBERT_LICENSE')
32
+ MUBERT_TOKEN = os.environ.get('MUBERT_TOKEN')
33
 
34
+ ##————————————————————————————————————
35
  def get_pat_token():
36
  r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess',
37
  json={
 
46
  })
47
 
48
  rdata = json.loads(r.text)
 
49
  assert rdata['status'] == 1, "probably incorrect e-mail"
50
+ pat = rdata['data']['pat']
51
+ print(f"pat: {pat}")
52
+ return pat
53
 
54
  def get_music(pat, prompt, track_duration, gen_intensity, gen_mode):
55
 
 
60
  {
61
  "text": prompt,
62
  "pat": pat,
63
+ "mode":gen_mode,
64
  "duration":track_duration,
65
+ "intensity": gen_intensity
66
  }
67
  })
68
 
69
  rdata = json.loads(r.text)
70
 
71
+ print(f"rdata: {rdata}")
72
  assert rdata['status'] == 1, rdata['error']['text']
73
+ track = rdata['data']['tasks'][0]['download_link']
74
+ print(track)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ local_file_path = "sample.mp3"
77
+
78
+ # Download the MP3 file from the URL
79
+ headers = {
80
+ 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7; rv:93.0) Gecko/20100101 Firefox/93.0'}
81
+
82
+ retries = 3
83
+ delay = 5 # in seconds
84
+ while retries > 0:
85
+ response = requests.get(track, headers=headers)
86
+ if response.status_code == 200:
87
+ break
88
+ retries -= 1
89
+ time.sleep(delay)
90
+ response = requests.get(track, headers=headers)
91
+ print(f"{response}")
92
+ # Save the downloaded content to a local file
93
+ with open(local_file_path, 'wb') as f:
94
+ f.write(response.content)
95
+ return "sample.mp3"
96
+
97
+
98
+ def get_results(text_prompt,track_duration,gen_intensity,gen_mode):
99
+ pat_token = get_pat_token()
100
+ music = get_music(pat_token, text_prompt, track_duration, gen_intensity, gen_mode)
101
+ return pat_token, music
102
+
103
+ def get_prompts(uploaded_image, track_duration, gen_intensity, gen_mode):
104
+ print("calling clip interrogator")
105
+ #prompt = img_to_text(uploaded_image, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
106
+ prompt = img_to_text(uploaded_image, 'best', 4, fn_index=1)[0]
107
+ print(prompt)
108
+ music_result = get_results(prompt, track_duration, gen_intensity, gen_mode)
109
+ wave_file = convert_mp3_to_wav(music_result[1])
110
+ #music_result = generate_track_by_prompt(pat, prompt, track_duration, gen_intensity, gen_mode)
111
+ #print(pat
112
+ time.sleep(1)
113
+ return wave_file, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
114
+
115
+
116
  def get_track_by_tags(tags, pat, duration, gen_intensity, gen_mode, maxit=20):
117
 
118
  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',