import streamlit as st
import torch
import torchaudio
import os
import numpy as np
import base64
from audiocraft.models import MusicGen
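# Streamlit app: turn a text prompt (plus genre and BPM hints) into a short music
# clip with Meta's MusicGen "medium" model from the audiocraft library.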
# Keep the batch size modest (reduced from 64 to 32) and clear any cached GPU
# memory up front to lower the risk of CUDA out-of-memory errors.
batch_size = 32
torch.cuda.empty_cache()
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
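# Load the MusicGen checkpoint once per server process. st.cache_resource keeps
# the model in memory across Streamlit reruns instead of reloading it on every
# widget interaction.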
@st.cache_resource()
def load_model():
    model = MusicGen.get_pretrained('facebook/musicgen-medium')
    return model
def generate_music_tensors(description, duration: int):
    model = load_model()
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration
    )
    with st.spinner("Generating Music..."):
        # MusicGen expects a list of text prompts; with return_tokens=True the
        # call returns a (waveform, tokens) tuple.
        output = model.generate(
            descriptions=[description],
            progress=True,
            return_tokens=True
        )
    st.success("Music Generation Complete!")
    return output
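# Rough usage sketch (shapes are an assumption based on MusicGen's 32 kHz mono
# output, not verified here): for a single prompt and a 10-second duration,
#   wav, tokens = generate_music_tensors("80s rock song with guitar and drums", 10)
# should give a waveform tensor of shape [1, 1, 320000] (batch, channels, samples).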
def save_audio(samples: torch.Tensor):
    # MusicGen generates audio at 32 kHz
    sample_rate = 32000
    save_path = "audio_output"
    os.makedirs(save_path, exist_ok=True)

    assert samples.dim() == 2 or samples.dim() == 3
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
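# Example: for a batch of N generated clips, save_audio(wav) writes
# audio_output/audio_0.wav ... audio_output/audio_{N-1}.wav, one file per item.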
def get_binary_file_downloader_html(bin_file, file_label='File'):
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
    return href
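# The helper above embeds the file as a base64 data URI in an <a> tag, rendered
# later via st.markdown(..., unsafe_allow_html=True). If you prefer a native
# widget, Streamlit's st.download_button is a simpler alternative, e.g.:
#   st.download_button("Download Audio", data, file_name="audio_0.wav")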
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen"
)
def main():
    st.title("🎧 AI Composer Medium-Model 🎧")
    st.subheader("Craft your perfect melody!")

    bpm = st.number_input("Enter Speed in BPM", min_value=60)
    text_area = st.text_area('Ex : 80s rock song with guitar and drums')
    st.text('')

    # Dropdown for genres
    selected_genre = st.selectbox("Select Genre", genres)

    st.subheader("2. Select time duration (In Seconds)")
    time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)

    # mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
    # instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
    # tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
    # melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
    if st.button('Let\'s Generate 🎶'):
        st.text('\n\n')
        st.subheader("Generated Music")

        # Build the text prompt from the free-text description, genre and BPM
        description = text_area
        if selected_genre:
            description += f" {selected_genre}"
        if bpm:
            description += f" {bpm} BPM"
        # if mood:
        #     description += f" {mood}"
        # if instrument:
        #     description += f" {instrument}"
        # if tempo:
        #     description += f" {tempo}"
        # if melody:
        #     description += f" {melody}"

        # Clear the CUDA memory cache before generating music
        torch.cuda.empty_cache()

        music_tensors = generate_music_tensors(description, time_slider)
        # generate() was called with return_tokens=True, so index 0 is the
        # waveform batch and index 1 the token tensor.
        music_tensor = music_tensors[0]
        save_audio(music_tensor)

        # Play and offer a download of the first generated clip
        idx = 0
        audio_filepath = os.path.join('audio_output', f'audio_{idx}.wav')
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()

        st.audio(audio_bytes, format='audio/wav')
        st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
if __name__ == "__main__":
    main()
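# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py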