sadafwalliyani committed
Commit 3d2ca89 (verified)
1 Parent(s): a23cc4b

Update app.py

Files changed (1)
  1. app.py +35 -48
app.py CHANGED
@@ -1,61 +1,42 @@
 import streamlit as st
 import torch
-import torchaudio
 import os
-import numpy as np
 import base64
-import math
-
-from audiocraft.data.audio_utils import convert_audio
-from audiocraft.data.audio import audio_write
-# from audiocraft.models.encodec import InterleaveStereoCompressionModel
-from audiocraft.models import MusicGen, MultiBandDiffusion
-from audiocraft.utils.notebook import display_audio
+import torchaudio
+import numpy as np
 from audiocraft.models import MusicGen
-# from audiocraft.models import audiogen
 
 genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
           "Lofi", "Chillpop","Country","R&G", "Folk","EDM", "Disco", "House", "Techno",]
 
 @st.cache_resource()
-def load_model():
-    model = MusicGen.get_pretrained('facebook/musicgen-medium')
+def load_model(model_name):
+    model = MusicGen.get_pretrained(model_name)
     return model
-
 
-def generate_music_tensors(description, duration: int, batch_size=1):
-    model = load_model()
-
-    model.set_generation_params(
-        use_sampling=True,
-        top_k=250,
-        duration=duration
-    )
-
-    with st.spinner("Generating Music..."):
-        output = []
-        for i in range(0, len(description), batch_size):
-            batch_descriptions = description[i:i+batch_size]
-            batch_output = model.generate(
-                descriptions=batch_descriptions,
+def generate_music_tensors(description, duration: int, batch_size=1, models=None):
+    outputs = {}
+    for model_name, model in models.items():
+        model.set_generation_params(
+            use_sampling=True,
+            top_k=250,
+            duration=duration
+        )
+
+        with st.spinner(f"Generating Music with {model_name}..."):
+            output = model.generate(
+                descriptions=description,
                 progress=True,
                 return_tokens=True
            )
-            output.extend(batch_output)
-
-        # output = model.generate(
-        #     descriptions=description,
-        #     progress=True,
-        #     return_tokens=True
-        # )
+            outputs[model_name] = output
 
    st.success("Music Generation Complete!")
-    return output
-
+    return outputs
 
 def save_audio(samples: torch.Tensor, filename):
     sample_rate = 30000
-    save_path = "/content/drive/MyDrive/Colab Notebooks/audio_output"
+    save_path = "audio_output"
     assert samples.dim() == 2 or samples.dim() == 3
 
     samples = samples.detach().cpu()
@@ -65,7 +46,7 @@ def save_audio(samples: torch.Tensor, filename):
     for idx, audio in enumerate(samples):
         audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
         torchaudio.save(audio_path, audio, sample_rate)
-        return audio_path
+    return audio_path
 
 def get_binary_file_downloader_html(bin_file, file_label='File'):
     with open(bin_file, 'rb') as f:
@@ -80,7 +61,7 @@ st.set_page_config(
 )
 
 def main():
-    st.title("🎧AI Composer Medium-Model 🎧")
+    st.title("🎧 AI Composer 🎧")
 
     st.subheader("Generate Music")
     st.write("Craft your perfect melody! Fill in the blanks below to create your music masterpiece:")
@@ -95,21 +76,27 @@ def main():
     tempo = st.selectbox("Select Tempo", ["Slow", "Moderate", "Fast"])
     melody = st.text_input("Enter Melody or Chord Progression", "e.g., C D:min G:7 C, Twinkle Twinkle Little Star")
 
+    models = {
+        'Medium': load_model('facebook/musicgen-medium'),
+        'Large': load_model('facebook/musicgen-large'),
+        # Add more models here as needed
+    }
+
     if st.button('Let\'s Generate 🎶'):
         st.text('\n\n')
         st.subheader("Generated Music")
 
         description = f"{text_area} {selected_genre} {bpm} BPM {mood} {instrument} {tempo} {melody}"
-        music_tensors = generate_music_tensors(description, time_slider, batch_size=2)
-
+        music_outputs = generate_music_tensors(description, time_slider, batch_size=2, models=models)
 
-        idx = 0
-        audio_path = save_audio(music_tensors[idx], "audio_output")
-        audio_file = open(audio_path, 'rb')
-        audio_bytes = audio_file.read()
+        for model_name, output in music_outputs.items():
+            idx = 0  # Assuming you want to access the first audio file for each model
+            audio_filepath = save_audio(output, f'audio_{model_name}_{idx}')
+            audio_file = open(audio_filepath, 'rb')
+            audio_bytes = audio_file.read()
 
-        st.audio(audio_bytes, format='audio/wav')
-        st.markdown(get_binary_file_downloader_html(audio_path, f'Audio_{idx}'), unsafe_allow_html=True)
+            st.audio(audio_bytes, format='audio/wav', label=f'{model_name} Model')
+            st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{model_name}_{idx}'), unsafe_allow_html=True)
 
 if __name__ == "__main__":
     main()
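Review note (not part of the commit): the updated generate_music_tensors passes the prompt string straight to model.generate(descriptions=description, ...) and no longer uses batch_size. audiocraft's MusicGen.generate expects a list of description strings and, with return_tokens=True, returns a (wav, tokens) tuple, so the waveform should be unpacked before it reaches save_audio. A minimal sketch of the per-model loop body under those assumptions:

with st.spinner(f"Generating Music with {model_name}..."):
    # MusicGen.generate takes a list of prompts; wrap the single prompt string.
    wav, tokens = model.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    # Keep only the waveform tensor; save_audio expects a 2-D or 3-D tensor.
    outputs[model_name] = wav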
 
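Review note (not part of the commit): in main(), the playback loop hands the raw generate() output to save_audio, assumes the audio_output directory already exists, and calls st.audio with a label= keyword that is not part of st.audio's signature. save_audio also hard-codes sample_rate = 30000, while the facebook/musicgen checkpoints generate 32 kHz audio exposed as model.sample_rate. A hedged sketch of a more defensive loop, reusing music_outputs, models, and get_binary_file_downloader_html from app.py and assuming each stored value is a [batch, channels, samples] waveform tensor:

import os
import torchaudio
import streamlit as st

os.makedirs("audio_output", exist_ok=True)  # the app writes WAV files into this folder

for model_name, wav in music_outputs.items():
    # Use the model's native sample rate instead of the hard-coded 30000.
    sample_rate = models[model_name].sample_rate
    audio_filepath = os.path.join("audio_output", f"audio_{model_name}_0.wav")
    torchaudio.save(audio_filepath, wav[0].detach().cpu(), sample_rate)

    with open(audio_filepath, "rb") as f:
        audio_bytes = f.read()

    st.caption(f"{model_name} Model")  # st.audio has no label argument
    st.audio(audio_bytes, format="audio/wav")
    st.markdown(get_binary_file_downloader_html(audio_filepath, f"Audio_{model_name}_0"),
                unsafe_allow_html=True)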