sadafwalliyani commited on
Commit
8f13b05
·
verified ·
1 Parent(s): b2df4c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -45
app.py CHANGED
@@ -1,28 +1,27 @@
 
1
  import streamlit as st
2
  import torch
3
  import torchaudio
4
- # from audiocraft.models import MusicGen
5
  import os
6
  import numpy as np
7
  import base64
8
- from audiocraft.models import MAGNeT
9
- from audiocraft.data.audio import audio_write
10
 
11
- from audiocraft.data.audio_utils import convert_audio
12
- from audiocraft.data.audio import audio_write
13
- from audiocraft.models.encodec import InterleaveStereoCompressionModel
14
- from audiocraft.models import MusicGen, MultiBandDiffusion
 
 
15
 
16
- genres = ["Pop","Hip-Hop", "Classical","Lofi", "Chillpop","Country","R&G", "Folk","EDM", "Disco", "House", "Techno",]
17
 
18
  @st.cache_resource()
19
  def load_model():
20
- model = MAGNeT.get_pretrained("facebook/magnet-medium-30secs")
21
- # model = MusicGen.get_pretrained('facebook/audiogen-medium')
22
  return model
23
-
24
 
25
- def generate_music_tensors(description, duration: int, batch_size=1):
26
  model = load_model()
27
 
28
  model.set_generation_params(
@@ -40,9 +39,8 @@ def generate_music_tensors(description, duration: int, batch_size=1):
40
 
41
  st.success("Music Generation Complete!")
42
  return output
43
-
44
 
45
- def save_audio(samples: torch.Tensor, filename):
46
  sample_rate = 30000
47
  save_path = "audio_output"
48
  assert samples.dim() == 2 or samples.dim() == 3
@@ -52,9 +50,8 @@ def save_audio(samples: torch.Tensor, filename):
52
  samples = samples[None, ...]
53
 
54
  for idx, audio in enumerate(samples):
55
- audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
56
  torchaudio.save(audio_path, audio, sample_rate)
57
- return audio_path
58
 
59
  def get_binary_file_downloader_html(bin_file, file_label='File'):
60
  with open(bin_file, 'rb') as f:
@@ -72,49 +69,53 @@ def main():
72
  st.title("🎧 AI Composer Medium-Model 🎧")
73
 
74
  st.subheader("Craft your perfect melody!")
75
-
76
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
77
- text_area = st.text_area('Example: 80s rock song with guitar and drums', height=50)
78
- selected_genre = st.selectbox("Select Genre (Optional)", genres, None)
 
 
 
 
 
79
  time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
80
-
81
- mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
82
- instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
83
- tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
84
- melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
85
 
86
  if st.button('Let\'s Generate 🎶'):
87
  st.text('\n\n')
88
  st.subheader("Generated Music")
89
-
90
- description = f"{text_area}"
 
91
  if selected_genre:
92
  description += f" {selected_genre}"
 
93
  if bpm:
94
  description += f" {bpm} BPM"
95
- if mood:
96
- description += f" {mood}"
97
- if instrument:
98
- description += f" {instrument}"
99
- if tempo:
100
- description += f" {tempo}"
101
- if melody:
102
- description += f" {melody}"
 
 
 
 
 
 
103
 
104
  music_tensors = generate_music_tensors(description, time_slider)
105
 
 
106
  idx = 0
107
-
108
- # audio_path = save_audio(music_tensors[idx], "audio_output")
109
- # audio_file = open(audio_path, 'rb')
110
- # audio_bytes = audio_file.read()
111
-
112
- # st.audio(audio_bytes, format='audio/wav')
113
- # st.markdown(get_binary_file_downloader_html(audio_path, f'Audio_{idx}'), unsafe_allow_html=True)
114
-
115
  music_tensor = music_tensors[idx]
116
- save_music_file = save_audio(music_tensor)
117
- audio_filepath = f'audio_output/audio_{idx}.wav'
118
  audio_file = open(audio_filepath, 'rb')
119
  audio_bytes = audio_file.read()
120
 
@@ -123,4 +124,6 @@ def main():
123
  st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
124
 
125
  if __name__ == "__main__":
126
- main()
 
 
 
1
+
2
  import streamlit as st
3
  import torch
4
  import torchaudio
5
+ from audiocraft.models import MusicGen
6
  import os
7
  import numpy as np
8
  import base64
 
 
9
 
10
+ # Before
11
+ batch_size = 64
12
+
13
+ # After
14
+ batch_size = 32
15
+ torch.cuda.empty_cache()
16
 
17
+ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
18
 
19
  @st.cache_resource()
20
  def load_model():
21
+ model = MusicGen.get_pretrained('facebook/musicgen-medium')
 
22
  return model
 
23
 
24
+ def generate_music_tensors(description, duration: int):
25
  model = load_model()
26
 
27
  model.set_generation_params(
 
39
 
40
  st.success("Music Generation Complete!")
41
  return output
 
42
 
43
+ def save_audio(samples: torch.Tensor):
44
  sample_rate = 30000
45
  save_path = "audio_output"
46
  assert samples.dim() == 2 or samples.dim() == 3
 
50
  samples = samples[None, ...]
51
 
52
  for idx, audio in enumerate(samples):
53
+ audio_path = os.path.join(save_path, f"audio_{idx}.wav")
54
  torchaudio.save(audio_path, audio, sample_rate)
 
55
 
56
  def get_binary_file_downloader_html(bin_file, file_label='File'):
57
  with open(bin_file, 'rb') as f:
 
69
  st.title("🎧 AI Composer Medium-Model 🎧")
70
 
71
  st.subheader("Craft your perfect melody!")
 
72
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
73
+
74
+ text_area = st.text_area('Ex : 80s rock song with guitar and drums')
75
+ st.text('')
76
+ # Dropdown for genres
77
+ selected_genre = st.selectbox("Select Genre", genres)
78
+
79
+ st.subheader("2. Select time duration (In Seconds)")
80
  time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
81
+ # mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
82
+ # instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
83
+ # tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
84
+ # melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
 
85
 
86
  if st.button('Let\'s Generate 🎶'):
87
  st.text('\n\n')
88
  st.subheader("Generated Music")
89
+
90
+ # Generate audio
91
+ description = text_area # Initialize description with text_area
92
  if selected_genre:
93
  description += f" {selected_genre}"
94
+ st.empty() # Hide the selected_genre selectbox after selecting one option
95
  if bpm:
96
  description += f" {bpm} BPM"
97
+ # if mood:
98
+ # description += f" {mood}"
99
+ # st.empty() # Hide the mood selectbox after selecting one option
100
+ # if instrument:
101
+ # description += f" {instrument}"
102
+ # st.empty() # Hide the instrument selectbox after selecting one option
103
+ # if tempo:
104
+ # description += f" {tempo}"
105
+ # st.empty() # Hide the tempo selectbox after selecting one option
106
+ # if melody:
107
+ # description += f" {melody}"
108
+
109
+ # Clear CUDA memory cache before generating music
110
+ torch.cuda.empty_cache()
111
 
112
  music_tensors = generate_music_tensors(description, time_slider)
113
 
114
+ # Only play the full audio for index 0
115
  idx = 0
 
 
 
 
 
 
 
 
116
  music_tensor = music_tensors[idx]
117
+ save_audio(music_tensor)
118
+ audio_filepath = f'/audio_output/audio_{idx}.wav'
119
  audio_file = open(audio_filepath, 'rb')
120
  audio_bytes = audio_file.read()
121
 
 
124
  st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
125
 
126
  if __name__ == "__main__":
127
+ main()
128
+
129
+