# Page header at capture time: "Spaces: Runtime error" (the hosted Space was reporting a runtime error).
import base64
import json
import time
import traceback

import gradio as gr
import httpx
from sentence_transformers import SentenceTransformer

from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
# Sentence-embedding model used to match free-text prompts against Mubert tags.
minilm = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')
# Tag embeddings are built once, at import time, from the model above and then
# reused by every prompt request instead of being recomputed per call.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Ask the Mubert RecordTrackTTM endpoint for a track and wait until it is downloadable.

    Args:
        tags: Tags describing the music; forwarded unchanged as the ``tags``
            request parameter.
        pat: Access token for the Mubert API (presumably a personal access
            token as returned by ``get_pat`` — confirm).
        duration: Forwarded unchanged as the ``duration`` request parameter
            (units presumably seconds — confirm against the Mubert API docs).
        maxit: Maximum number of times the download link is polled; polls are
            spaced one second apart.
        loop: If True, request ``mode="loop"``; otherwise ``mode="track"``.

    Returns:
        The track's download URL, once a GET on it answers HTTP 200.

    Raises:
        RuntimeError: The API response does not report ``status == 1``.
        TimeoutError: The download link never answered HTTP 200 within
            ``maxit`` polls (previously this silently returned ``None``).
    """
    mode = "loop" if loop else "track"
    response = httpx.post(
        'https://api-b2b.mubert.com/v2/RecordTrackTTM',
        json={
            "method": "RecordTrackTTM",
            "params": {
                "pat": pat,
                "duration": duration,
                "tags": tags,
                "mode": mode,
            },
        },
    )
    rdata = response.json()
    if rdata.get('status') != 1:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let a failed request fall through to an
        # opaque KeyError on rdata['data'] below.
        error = rdata.get('error')
        message = error.get('text') if isinstance(error, dict) else None
        raise RuntimeError(message or f"Mubert RecordTrackTTM request failed: {rdata}")
    trackurl = rdata['data']['tasks'][0]['download_link']

    print('Generating track ', end='')
    # The link exists immediately but answers 200 only once the track is ready,
    # so poll it. NOTE(review): each successful GET downloads the whole file;
    # a HEAD request may be cheaper if the server supports it — confirm.
    for attempt in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            return trackurl
        if attempt < maxit - 1:  # no point sleeping after the final poll
            time.sleep(1)
    raise TimeoutError(f"Track was not ready after {maxit} polls: {trackurl}")
def generate_track_by_prompt(prompt):
    """Generate a track for a free-text prompt and return its URL for display.

    The prompt is matched to Mubert tags using the ``minilm`` sentence model and
    the precomputed ``mubert_tags_embeddings``, then a 30-unit (duration passed
    to ``get_track_by_tags``), non-looping track is requested.

    Args:
        prompt: Text entered in the Gradio textbox.

    Returns:
        The track download URL on success; otherwise a human-readable error
        message. Either way the string is shown directly in the "Result" box.
    """
    try:
        # NOTE(review): this literal looks like scraper email-obfuscation
        # residue — the real address passed to get_pat was lost. Confirm.
        pat = get_pat("[email protected]")
        # Presumably each entry is (prompt, tags); only the tags are needed.
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt])[0]
        result = get_track_by_tags(tags, pat, 30, loop=False)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # request, but keep the traceback in the server logs for debugging.
        traceback.print_exc()
        # Some exceptions carry no message; fall back to the exception type so
        # the user never sees an empty result.
        return str(e) or type(e).__name__
    print(result)
    if result is None:
        # Without this, a missing URL reaches the UI as an empty box.
        return "Track generation did not finish in time; please try again."
    return result
# Single text box in, single text box out (the track URL or an error message).
iface = gr.Interface(fn=generate_track_by_prompt, inputs=["text"], outputs=[gr.Text(label="Result")])
# Allow up to 32 queued requests with up to 20 processed concurrently.
# Gradio 4.x rejects `concurrency_count` in `queue()` (its replacement is
# `default_concurrency_limit`), while Gradio 3.x does not know the new keyword
# and raises TypeError for it. Try the new spelling first, then fall back.
try:
    iface.queue(max_size=32, default_concurrency_limit=20)
except TypeError:
    iface.queue(max_size=32, concurrency_count=20)
iface.launch()