KingNish committed on
Commit b8a38aa · 1 Parent(s): 773a80a

modified: app.py

Files changed (1): app.py (+354 -113)

app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import subprocess
-import os
+import os
 import shutil
 import tempfile
 import spaces
@@ -27,10 +27,10 @@ def install_flash_attn():
 # Install flash-attn
 install_flash_attn()
 
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download
 
 # Create xcodec_mini_infer folder
-folder_path = './inference/xcodec_mini_infer'
+folder_path = './xcodec_mini_infer'
 
 # Create the folder if it doesn't exist
 if not os.path.exists(folder_path):
@@ -41,22 +41,347 @@ else:
 
 snapshot_download(
     repo_id = "m-a-p/xcodec_mini_infer",
-    local_dir = "./inference/xcodec_mini_infer"
+    local_dir = "./xcodec_mini_infer"
 )
 
-# Change to the "inference" directory
-inference_dir = "./inference"
-try:
-    os.chdir(inference_dir)
-    print(f"Changed working directory to: {os.getcwd()}")
-except FileNotFoundError:
-    print(f"Directory not found: {inference_dir}")
-    exit(1)
+# Add xcodec_mini_infer and descriptaudiocodec to sys path
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
+
+import argparse
+import numpy as np
+import json
+from omegaconf import OmegaConf
+import torchaudio
+from torchaudio.transforms import Resample
+import soundfile as sf
+
+import uuid
+from tqdm import tqdm
+from einops import rearrange
+from codecmanipulator import CodecManipulator
+from mmtokenizer import _MMSentencePieceTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+import glob
+import time
+import copy
+from collections import Counter
+from models.soundstream_hubert_new import SoundStream
+from vocoder import build_codec_model, process_audio
+from post_process_audio import replace_low_freq_with_energy_matched
+import re
+
+
+# --- Arguments and Model Loading from infer.py ---
+parser = argparse.ArgumentParser()
+# Model Configuration:
+parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
+parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
+parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
+# Prompt
+parser.add_argument("--genre_txt", type=str, default="", help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.") # Modified: removed required=True and using default=""
+parser.add_argument("--lyrics_txt", type=str, default="", help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.") # Modified: removed required=True and using default=""
+parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
+parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
+parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
+parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
+# Output
+parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
+parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
+parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
+parser.add_argument("--cuda_idx", type=int, default=0)
+# Config for xcodec and upsampler
+parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
+parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
+parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
+parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
+parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
+parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
+
+
+args = parser.parse_args([]) # Modified: Pass empty list to parse_args to avoid command line parsing in Gradio
+
+if args.use_audio_prompt and not args.audio_prompt_path:
+    raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
+model_name = args.stage1_model # Modified: Renamed 'model' to 'model_name' to avoid shadowing the loaded model later
+cuda_idx = args.cuda_idx
+max_new_tokens_config = args.max_new_tokens # Modified: Renamed 'max_new_tokens' to 'max_new_tokens_config' to avoid shadowing the Gradio input
+stage1_output_dir = os.path.join(args.output_dir, f"stage1")
+os.makedirs(stage1_output_dir, exist_ok=True)
+
+# load tokenizer and model
+device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
+
+# Now you can use `device` to move your tensors or models to the GPU (if available)
+print(f"Using device: {device}")
+
+mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
+
+codectool = CodecManipulator("xcodec", 0, 1)
+model_config = OmegaConf.load(args.basic_model_config)
+codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
+parameter_dict = torch.load(args.resume_path, map_location='cpu')
+codec_model.load_state_dict(parameter_dict['codec_model'])
+codec_model.to(device)
+codec_model.eval()
+
+class BlockTokenRangeProcessor(LogitsProcessor):
+    def __init__(self, start_id, end_id):
+        self.blocked_token_ids = list(range(start_id, end_id))
+
+    def __call__(self, input_ids, scores):
+        scores[:, self.blocked_token_ids] = -float("inf")
+        return scores
+
+def load_audio_mono(filepath, sampling_rate=16000):
+    audio, sr = torchaudio.load(filepath)
+    # Convert to mono
+    audio = torch.mean(audio, dim=0, keepdim=True)
+    # Resample if needed
+    if sr != sampling_rate:
+        resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+        audio = resampler(audio)
+    return audio
+
+def split_lyrics(lyrics):
+    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+    segments = re.findall(pattern, lyrics, re.DOTALL)
+    structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+    return structured_lyrics
+
+def generate_music(genres, lyrics_content, num_segments_run, max_new_tokens_run): # Modified: Function to encapsulate generation logic
+    stage1_output_set_local = [] # Modified: Local variable to store output paths
+
+    lyrics = split_lyrics(lyrics_content)
+    # intruction
+    full_lyrics = "\n".join(lyrics)
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+    prompt_texts += lyrics
+
+    random_id = uuid.uuid4()
+    output_seq = None
+
+    # Here is suggested decoding config
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+    # special tokens
+    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+    raw_output = None
+
+    # Format text prompt
+    run_n_segments = min(num_segments_run+1, len(lyrics)) # Modified: Use passed num_segments_run
+
+    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+    global model # Modified: Declare model as global to use the loaded model in Gradio scope
+
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        guidance_scale = 1.5 if i <=1 else 1.2
+        if i==0:
+            continue
+        if i==1:
+            if args.use_audio_prompt:
+                audio_prompt = load_audio_mono(args.audio_prompt_path)
+                audio_prompt.unsqueeze_(0)
+                with torch.no_grad():
+                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                raw_codes = raw_codes.transpose(0, 1)
+                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                # Format audio prompt
+                code_ids = codectool.npy2ids(raw_codes[0])
+                audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
+                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+            else:
+                head_id = mmtokenizer.tokenize(prompt_texts[0])
+            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        else:
+            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+        # Use window slicing in case output sequence exceeds the context of model
+        max_context = 16384-max_new_tokens_config-1 # Modified: Use max_new_tokens_config
+        if input_ids.shape[-1] > max_context:
+            print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+            input_ids = input_ids[:, -(max_context):]
+        with torch.no_grad():
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens_run, # Modified: Use max_new_tokens_run
+                min_new_tokens=100,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
+            )
+            if output_seq[0][-1].item() != mmtokenizer.eoa:
+                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+        if i > 1:
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+        else:
+            raw_output = output_seq
+        print(len(raw_output))
+
+    # save raw output and check sanity
+    ids = raw_output[0].cpu().numpy()
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+    if len(soa_idx)!=len(eoa_idx):
+        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+    vocals = []
+    instrumentals = []
+    range_begin = 1 if args.use_audio_prompt else 0
+    for i in range(range_begin, len(soa_idx)):
+        codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
+        vocals.append(vocals_ids)
+        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
+        instrumentals.append(instrumentals_ids)
+    vocals = np.concatenate(vocals, axis=1)
+    instrumentals = np.concatenate(instrumentals, axis=1)
+    vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens_run}_vocal_{random_id}".replace('.', '@')+'.npy') # Modified: Use max_new_tokens_run in filename
+    inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens_run}_instrumental_{random_id}".replace('.', '@')+'.npy') # Modified: Use max_new_tokens_run in filename
+    np.save(vocal_save_path, vocals)
+    np.save(inst_save_path, instrumentals)
+    stage1_output_set_local.append(vocal_save_path)
+    stage1_output_set_local.append(inst_save_path)
+
+
+    # offload model - Removed offloading for gradio integration to keep model loaded
+    # if not args.disable_offload_model:
+    #     model.cpu()
+    #     del model
+    #     torch.cuda.empty_cache()
+
+    print("Converting to Audio...")
+
+    # convert audio tokens to audio
+    def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+        folder_path = os.path.dirname(path)
+        if not os.path.exists(folder_path):
+            os.makedirs(folder_path)
+        limit = 0.99
+        max_val = wav.abs().max()
+        wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+        torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+    # reconstruct tracks
+    recons_output_dir = os.path.join(args.output_dir, "recons")
+    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+    os.makedirs(recons_mix_dir, exist_ok=True)
+    tracks = []
+    for npy in stage1_output_set_local: # Modified: Use stage1_output_set_local
+        codec_result = np.load(npy)
+        decodec_rlt=[]
+        with torch.no_grad():
+            decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
+        decoded_waveform = decoded_waveform.cpu().squeeze(0)
+        decodec_rlt.append(torch.as_tensor(decoded_waveform))
+        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        tracks.append(save_path)
+        save_audio(decodec_rlt, save_path, 16000)
+    # mix tracks
+    for inst_path in tracks:
+        try:
+            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+               and 'instrumental' in inst_path:
+                # find pair
+                vocal_path = inst_path.replace('instrumental', 'vocal')
+                if not os.path.exists(vocal_path):
+                    continue
+                # mix
+                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                vocal_stem, sr = sf.read(inst_path)
+                instrumental_stem, _ = sf.read(vocal_path)
+                mix_stem = (vocal_stem + instrumental_stem) / 1
+                sf.write(recons_mix, mix_stem, sr)
+        except Exception as e:
+            print(e)
+
+    # vocoder to upsample audios
+    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
+    vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
+    vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
+    vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
+    os.makedirs(vocoder_mix_dir, exist_ok=True)
+    os.makedirs(vocoder_stems_dir, exist_ok=True)
+
+    instrumental_output = None # Initialize outside try block
+    vocal_output = None # Initialize outside try block
+    recons_mix_path = "" # Initialize outside try block
+
+
+    for npy in stage1_output_set_local: # Modified: Use stage1_output_set_local
+        if 'instrumental' in npy:
+            # Process instrumental
+            instrumental_output = process_audio(
+                npy,
+                os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
+                args.rescale,
+                args,
+                inst_decoder,
+                codec_model
+            )
+        else:
+            # Process vocal
+            vocal_output = process_audio(
+                npy,
+                os.path.join(vocoder_stems_dir, 'vocal.mp3'),
+                args.rescale,
+                args,
+                vocal_decoder,
+                codec_model
+            )
+    # mix tracks
+    try:
+        mix_output = instrumental_output + vocal_output
+        recons_mix_path_temp = os.path.join(recons_mix_dir, os.path.basename(recons_mix)) # Use recons_mix from previous step
+        save_audio(mix_output, recons_mix_path_temp, 44100, args.rescale)
+        print(f"Created mix: {recons_mix_path_temp}")
+        recons_mix_path = recons_mix_path_temp # Assign to outer scope variable
+    except RuntimeError as e:
+        print(e)
+        print(f"mix {recons_mix_path} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
+
+    # Post process
+    final_output_path = os.path.join(args.output_dir, os.path.basename(recons_mix_path)) # Use recons_mix_path from previous step
+    replace_low_freq_with_energy_matched(
+        a_file=recons_mix_path, # 16kHz # Use recons_mix_path
+        b_file=recons_mix_path_temp, # 48kHz # Use recons_mix_path_temp
+        c_file=final_output_path,
+        cutoff_freq=5500.0
+    )
+    print("All process Done")
+    return final_output_path # Modified: Return the final output audio path
+
+
+# Gradio UI
+model = AutoModelForCausalLM.from_pretrained( # Load model here for Gradio scope
+    "m-a-p/YuE-s1-7B-anneal-en-cot",
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
+).to(device).eval() # Modified: Load model globally for Gradio to access
 
 def empty_output_folder(output_dir):
     # List all files in the output directory
     files = os.listdir(output_dir)
-
+
     # Iterate over the files and remove them
     for file in files:
         file_path = os.path.join(output_dir, file)
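
The split_lyrics helper added in this hunk drives segmentation: it splits a lyric sheet on bracketed section headers such as [verse] and [chorus], and generate_music then prompts the model with the full sheet once, followed by one segment at a time. A self-contained sketch of its behavior (the sample lyric lines are illustrative, not from the commit):

import re

def split_lyrics(lyrics):
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    # Each segment keeps its [header] and is re-terminated with a blank line
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

sample = "[verse]\nStaring at the sunset\n\n[chorus]\nHold on to the light\n"
print(split_lyrics(sample))
# ['[verse]\nStaring at the sunset\n\n', '[chorus]\nHold on to the light\n\n']
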
@@ -70,54 +395,8 @@ def empty_output_folder(output_dir):
         except Exception as e:
             print(f"Error deleting file {file_path}: {e}")
 
-# Function to create a temporary file with string content
-def create_temp_file(content, prefix, suffix=".txt"):
-    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
-    # Ensure content ends with newline and normalize line endings
-    content = content.strip() + "\n\n" # Add extra newline at end
-    content = content.replace("\r\n", "\n").replace("\r", "\n")
-    temp_file.write(content)
-    temp_file.close()
-
-    # Debug: Print file contents
-    print(f"\nContent written to {prefix}{suffix}:")
-    print(content)
-    print("---")
-
-    return temp_file.name
-
-def get_last_mp3_file(output_dir):
-    # List all files in the output directory
-    files = os.listdir(output_dir)
-
-    # Filter only .mp3 files
-    mp3_files = [file for file in files if file.endswith('.mp3')]
-
-    if not mp3_files:
-        print("No .mp3 files found in the output folder.")
-        return None
-
-    # Get the full path for the mp3 files
-    mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
-
-    # Sort the files based on the modification time (most recent first)
-    mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
-
-    # Return the most recent .mp3 file
-    return mp3_files_with_path[0]
-
-device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
-
-model = AutoModelForCausalLM.from_pretrained(
-    "m-a-p/YuE-s1-7B-anneal-en-cot",
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
-)
-model.to(device)
-model.eval()
-
 @spaces.GPU(duration=120)
-def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200):
+def infer_gradio(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=200): # Modified: Renamed infer to infer_gradio to avoid conflict
 
     # Ensure the output folder exists
     output_dir = "./output"
@@ -125,55 +404,17 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     print(f"Output folder ensured at: {output_dir}")
 
     empty_output_folder(output_dir)
-
-    # Command and arguments with optimized settings
-    command = [
-        "python", "infer.py",
-        "--stage1_model", model,
-        # "--stage2_model", "m-a-p/YuE-s2-1B-general",
-        "--genre_txt", f"{genre_txt_content}",
-        "--lyrics_txt", f"{lyrics_txt_content}",
-        "--run_n_segments", f"{num_segments}",
-        # "--stage2_batch_size", "4",
-        "--output_dir", f"{output_dir}",
-        "--cuda_idx", "0",
-        "--max_new_tokens", f"{max_new_tokens}",
-        # "--disable_offload_model"
-    ]
-
-    # Set up environment variables for CUDA with optimized settings
-    env = os.environ.copy()
-
-    # Execute the command
-    try:
-        subprocess.run(command, check=True, env=env)
-        print("Command executed successfully!")
-
-        # Check and print the contents of the output folder
-        output_files = os.listdir(output_dir)
-        if output_files:
-            print("Output folder contents:")
-            for file in output_files:
-                print(f"- {file}")
-
-            last_mp3 = get_last_mp3_file(output_dir)
-
-            if last_mp3:
-                print("Last .mp3 file:", last_mp3)
-                return last_mp3
-            else:
-                return None
-        else:
-            print("Output folder is empty.")
-            return None
-    except subprocess.CalledProcessError as e:
-        print(f"Error occurred: {e}")
+
+    # Call the generation function directly
+    output_audio_path = generate_music(genre_txt_content, lyrics_txt_content, int(num_segments), int(max_new_tokens)) # Modified: Call generate_music and pass num_segments and max_new_tokens as int
+
+    if output_audio_path and os.path.exists(output_audio_path):
+        print("Generated audio file:", output_audio_path)
+        return output_audio_path
+    else:
+        print("No audio file generated or path is invalid.")
     return None
-    finally:
-        # Clean up temporary files
-        print("Temporary files deleted.")
 
-# Gradio
 
 with gr.Blocks() as demo:
     with gr.Column():
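
This is the heart of the commit: instead of shelling out to infer.py with subprocess, the Space now calls generate_music in-process through the renamed infer_gradio wrapper, so the model stays loaded between requests. A hypothetical direct invocation, assuming the module-level globals (model, mmtokenizer, codec_model) have finished loading; the genre and lyric values below are placeholders, not from the commit:

# Hypothetical call with illustrative arguments.
audio_path = infer_gradio(
    genre_txt_content="uplifting pop airy vocal electronic",
    lyrics_txt_content="[verse]\nStaring at the sunset\n\n[chorus]\nHold on to the light\n",
    num_segments=2,
    max_new_tokens=200,
)
print(audio_path)  # path to the final mixed mp3, or None on failure
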
@@ -182,7 +423,7 @@ with gr.Blocks() as demo:
         <div style="display:flex;column-gap:4px;">
             <a href="https://github.com/multimodal-art-projection/YuE">
                 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-            </a>
+            </a>
             <a href="https://map-yue.github.io">
                 <img src='https://img.shields.io/badge/Project-Page-green'>
             </a>
@@ -195,7 +436,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             genre_txt = gr.Textbox(label="Genre")
             lyrics_txt = gr.Textbox(label="Lyrics")
-
+
         with gr.Column():
             if is_shared_ui:
                 num_segments = gr.Number(label="Number of Segments", value=2, interactive=True)
@@ -242,16 +483,16 @@ Through the highs and lows, I'mma keep it real
 Living out my dreams with this mic and a deal
 """
             ]
-        ],
+        ],
         inputs = [genre_txt, lyrics_txt],
         outputs = [music_out],
         cache_examples = False,
         # cache_mode="lazy",
-        fn=infer
+        fn=infer_gradio # Modified: Use infer_gradio
     )
-
+
     submit_btn.click(
-        fn = infer,
+        fn = infer_gradio, # Modified: Use infer_gradio
         inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
         outputs = [music_out]
     )
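
For reference, the BlockTokenRangeProcessor introduced in this commit is a plain transformers LogitsProcessor that assigns -inf to a contiguous range of token ids so they can never be sampled; here it blocks the text-token id range so Stage 1 decoding stays on codec tokens. A minimal standalone sketch of the masking behavior on a toy vocabulary:

import torch
from transformers import LogitsProcessor

class BlockTokenRangeProcessor(LogitsProcessor):
    def __init__(self, start_id, end_id):
        # Token ids in [start_id, end_id) are forbidden
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

scores = torch.zeros(1, 10)  # toy vocabulary of 10 ids
masked = BlockTokenRangeProcessor(0, 4)(None, scores)
print(masked)  # ids 0-3 are now -inf and will never be sampled
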