Spaces:
Running
Running
File size: 3,710 Bytes
2fff01c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
import os
import numpy as np
import base64
@st.cache_resource()
def load_model():
model = MusicGen.get_pretrained('facebook/musicgen-small')
return model
@st.cache_resource()
def generate_music_tensors(description, duration: int):
model = load_model()
model.set_generation_params(
use_sampling=True,
top_k=250,
duration=duration
)
output = model.generate(
descriptions=[description],
progress=True,
return_tokens=True
)
return output[0]
def save_audio(samples: torch.Tensor):
"""Renders an audio player for the given audio samples and saves them to a local directory.
Args:
samples (torch.Tensor): a Tensor of decoded audio samples
with shapes [B, C, T] or [C, T]
sample_rate (int): sample rate audio should be displayed with.
save_path (str): path to the directory where audio should be saved.
"""
print("Samples (inside function): ", samples)
sample_rate = 30000
save_path = "audio_output/"
assert samples.dim() == 2 or samples.dim() == 3
samples = samples.detach().cpu()
if samples.dim() == 2:
samples = samples[None, ...]
for idx, audio in enumerate(samples):
audio_path = os.path.join(save_path, f"audio_{idx}.wav")
torchaudio.save(audio_path, audio, sample_rate)
def get_binary_file_downloader_html(bin_file, file_label='File'):
with open(bin_file, 'rb') as f:
data = f.read()
bin_str = base64.b64encode(data).decode()
href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
return href
st.set_page_config(
page_icon= "musical_note",
page_title= "Music Gen"
)
def main():
with st.sidebar:
st.header("""⚙️ Parameters ⚙️""",divider="rainbow")
st.text("")
st.subheader("1. Enter your music description.......")
text_area = st.text_area('Ex : 80s rock song with guitar and drums')
st.text('')
st.subheader("2. Select time duration (In Seconds)")
time_slider = st.slider("Select time duration (In Seconds)", 0, 20, 10)
st.title("""🎵 Text to Music Generator 🎵""")
st.text('')
left_co,right_co = st.columns(2)
left_co.write("""Music Generation using Meta AI, through a prompt""")
left_co.write(("""PS : First generation may take some time as it loads the full model and requirements"""))
#container1 = st.container()
#container1.write("""Music coupled with Image Generation using a prompt""")
#container1.write("""PS : First generation may take some time as it loads the full model and requirements""")
if st.sidebar.button('Generate !'):
gif_url = "https://media.giphy.com/media/26Fffy7jqQW8gVg8o/giphy.gif"
with right_co:
with st.spinner("Generating"):
st.image(gif_url,width=250)
with left_co:
st.text('')
st.text('')
st.text('')
st.text('')
st.text('')
st.text('')
st.subheader("Generated Music")
music_tensors = generate_music_tensors(text_area, time_slider)
save_music_file = save_audio(music_tensors)
audio_filepath = 'audio_output/audio_0.wav'
audio_file = open(audio_filepath, 'rb')
audio_bytes = audio_file.read()
st.audio(audio_bytes)
st.markdown(get_binary_file_downloader_html(audio_filepath, 'Audio'), unsafe_allow_html=True)
if __name__ == "__main__":
main()
|