|
|
|
import gradio as gr
|
|
import librosa
|
|
import numpy as np
|
|
from openvino import runtime as ov
|
|
import soundfile as sf
|
|
import warnings
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# Silence FutureWarning and UserWarning for the whole process. No module
# filter is given, so this also hides them from every other library.
# Filters are inserted in this order, matching two sequential calls.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category
|
|
|
|
def estimate_key(y, sr):
    """Return the name of the pitch class with the most average chroma energy.

    This is a rough key estimate: it reports only the dominant pitch class
    (e.g. ``"A"``), with no major/minor distinction.

    Args:
        y: Audio time series passed to ``librosa.feature.chroma_cqt``.
        sr: Sampling rate of ``y``.

    Returns:
        One of the twelve sharp-spelled names ``"C"``, ``"C#"``, ... ``"B"``.
    """
    pitch_classes = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')
    # Average each chroma row over time, then pick the strongest row.
    mean_energy = librosa.feature.chroma_cqt(y=y, sr=sr).mean(axis=1)
    return pitch_classes[int(mean_energy.argmax())]
|
|
|
|
def classify_instrument(spectral_centroid, rms_energy):
    """Map a stream's spectral centroid and RMS energy to a coarse instrument label.

    Centroid bands (strict upper bounds): < 500 -> "bass"/"sub",
    < 2000 -> "drums"/"perc", < 4000 -> "synth", otherwise "high".
    In the two lowest bands, the louder label is chosen only when the energy
    strictly exceeds 0.1 or 0.15 respectively.

    Returns:
        One of "bass", "sub", "drums", "perc", "synth", "high".
    """
    if spectral_centroid < 500:
        return "bass" if rms_energy > 0.1 else "sub"
    if spectral_centroid < 2000:
        return "drums" if rms_energy > 0.15 else "perc"
    if spectral_centroid < 4000:
        return "synth"
    # Reached for centroids >= 4000 and for values that fail every
    # comparison above (e.g. NaN).
    return "high"
|
|
|
|
def get_musical_tempo_description(tempo):
    """Translate a tempo in BPM into a one-word feel descriptor.

    Bands use strict upper bounds: < 70 "slow", < 100 "chill",
    < 120 "upbeat", < 140 "energetic", anything else "fast".
    """
    bands = (
        (70, "slow"),
        (100, "chill"),
        (120, "upbeat"),
        (140, "energetic"),
    )
    for upper_bound, label in bands:
        if tempo < upper_bound:
            return label
    # Tempos >= 140, or values that compare false against every bound (NaN).
    return "fast"
|
|
|
|
def generate_prompt(keys, avg_tempo, streams_info, genre="electronic"):
    """Build a concise, Suno-friendly prompt of at most 200 characters.

    Args:
        keys: List of per-stream key names. The most frequent one is used;
            ties go to the key that appears first. Falls back to "C" when empty.
        avg_tempo: Tempo in BPM. It is truncated to an int for display and
            mapped to a word via ``get_musical_tempo_description``.
        streams_info: Iterable of dicts, each with a ``'type'`` entry. The two
            most frequent types are listed (ties keep first-seen order).
        genre: Genre word inserted into the prompt.

    Returns:
        The prompt string. If it would exceed 200 characters it is cut to
        197 characters plus "...".
    """
    max_length = 200

    # Iterate the list itself, not a set. max() returns the first maximal item,
    # so ties are broken deterministically by first occurrence. Set iteration
    # order for strings depends on hash randomization and can vary per run.
    most_common_key = max(keys, key=keys.count) if keys else "C"

    instrument_counts = {}
    for info in streams_info:
        inst_type = info['type']
        instrument_counts[inst_type] = instrument_counts.get(inst_type, 0) + 1

    # sorted() is stable even with reverse=True, so equally common types keep
    # the order in which they were first seen.
    main_elements = sorted(instrument_counts, key=instrument_counts.get, reverse=True)[:2]
    tempo_desc = get_musical_tempo_description(avg_tempo)

    prompt = f"{most_common_key} {int(avg_tempo)}bpm {tempo_desc} {genre}"
    # Omit the "with ..." clause when there is nothing to list, instead of
    # emitting a dangling "with ,".
    if main_elements:
        prompt += f" with {' + '.join(main_elements)}"
    prompt += ", dark atmosphere + reverb"

    if len(prompt) > max_length:
        prompt = prompt[:max_length - 3] + "..."

    return prompt
|
|
|
|
# Compiled OpenVINO models keyed by model path. The separation network is then
# read and compiled once per process rather than on every request.
_COMPILED_MODELS = {}


def _get_compiled_model(model_path):
    """Return the CPU-compiled OpenVINO model at ``model_path``, compiling it on first use."""
    compiled_model = _COMPILED_MODELS.get(model_path)
    if compiled_model is None:
        core = ov.Core()
        compiled_model = core.compile_model(core.read_model(model_path), "CPU")
        _COMPILED_MODELS[model_path] = compiled_model
    return compiled_model


def _separate_streams(y, model_path):
    """Run the separation model on samples ``y`` and return batch item 0 of its first output."""
    compiled_model = _get_compiled_model(model_path)
    output_node = compiled_model.output(0)

    # NOTE(review): the raw mono samples are reshaped straight into this fixed
    # shape. Presumably it matches the exported model's input — confirm.
    # Audio with more samples than the shape holds is truncated; shorter
    # audio is zero-padded.
    target_shape = (1, 4, 2048, 336)
    total_size = int(np.prod(target_shape))
    if len(y) < total_size:
        input_data = np.pad(y, (0, total_size - len(y)), mode='constant')
    else:
        input_data = y[:total_size]
    input_data = input_data.reshape(target_shape).astype(np.float32)

    # Create a fresh infer request per call. The cached compiled model may be
    # shared by concurrent requests, so a single request object must not be.
    infer_request = compiled_model.create_infer_request()
    outputs = infer_request.infer([ov.Tensor(input_data)])[output_node]
    return outputs[0]


def _analyze_stream(stream, sr):
    """Return ``(key, tempo, info)`` for one stream, or None if it is shorter than 0.1 s."""
    # float32 contiguous, matching what librosa.load produced in the former
    # write-then-reload path.
    stream = np.ascontiguousarray(stream, dtype=np.float32)
    if len(stream) < sr * 0.1:
        return None

    tempo, _ = librosa.beat.beat_track(y=stream, sr=sr)
    centroid = np.mean(librosa.feature.spectral_centroid(y=stream, sr=sr))
    rms = np.mean(librosa.feature.rms(y=stream))
    key = estimate_key(stream, sr)

    info = {
        'type': classify_instrument(centroid, rms),
        'centroid': centroid,
        'energy': rms,
    }
    return key, tempo, info


def process_audio(audio_path, genre):
    """Analyze an uploaded audio file and generate a Suno prompt from it.

    The audio is loaded at its native sample rate and run through the OpenVINO
    model at ``models/htdemucs_v4.xml`` (next to this file). Each output stream
    is then analyzed for tempo, spectral centroid, RMS energy and dominant
    pitch class. Streams are analyzed in memory. Writing them to a shared
    CWD-relative temp directory raced between concurrent requests, and
    soundfile's default 16-bit PCM WAV clipped and quantized the data.

    Args:
        audio_path: Path to the audio file (Gradio ``filepath`` input). May be
            None when the user submits without uploading.
        genre: Genre word forwarded to ``generate_prompt``.

    Returns:
        ``(prompt, "Character count: N")`` on success, otherwise
        ``(error message, "Processing failed")``. Errors are returned rather
        than raised because both values are shown directly in the UI.
    """
    if not audio_path:
        return "Error: No audio file provided.", "Processing failed"

    try:
        y, sr = librosa.load(audio_path, sr=None)
        print(f"Audio loaded: {len(y)} samples, Sample rate: {sr}")

        model_path = os.path.join(os.path.dirname(__file__), "models", "htdemucs_v4.xml")
        separated_audios = _separate_streams(y, model_path)

        keys = []
        avg_tempos = []
        streams_info = []
        for i, separated in enumerate(separated_audios):
            try:
                result = _analyze_stream(separated.reshape(-1), sr)
            except Exception as e:
                # Best effort: one unanalyzable stream should not sink the rest.
                print(f"Warning: Could not process stream {i+1}: {str(e)}")
                continue
            if result is None:
                continue
            key_s, tempo_s, info = result
            keys.append(key_s)
            avg_tempos.append(tempo_s)
            streams_info.append(info)

        if not avg_tempos:
            return "Error: No valid audio streams were processed.", "Processing failed"

        prompt = generate_prompt(keys, np.mean(avg_tempos), streams_info, genre)
        return prompt, f"Character count: {len(prompt)}"

    except Exception as e:
        # UI boundary: report any failure as text instead of crashing the handler.
        return f"Error processing the file: {str(e)}", "Processing failed"
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Interface that turns an uploaded audio file plus a genre into a prompt.

    Inputs are an audio upload (passed to ``process_audio`` as a file path) and
    a genre dropdown. Outputs are the generated prompt and a status line.
    """
    audio_input = gr.Audio(type="filepath", label="Upload Audio File")
    genre_input = gr.Dropdown(
        choices=["electronic", "ambient", "trap", "synthwave", "house", "techno"],
        label="Select Genre",
        value="electronic",
    )
    prompt_output = gr.Textbox(label="Generated Prompt")
    status_output = gr.Textbox(label="Status")

    return gr.Interface(
        fn=process_audio,
        inputs=[audio_input, genre_input],
        outputs=[prompt_output, status_output],
        title="Audio Analysis to Suno Prompt Generator",
        description="Upload an audio file to generate a Suno-compatible prompt based on its musical characteristics.",
        examples=[],
        cache_examples=False,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_interface().launch()