# Spaces:
# Runtime error
# Runtime error
import sys
import time

import numpy as np
from keras.activations import relu
from keras.layers import (AtrousConvolution1D, AtrousConvolution2D,
                          Convolution1D, Convolution2D, Dense, Flatten,
                          Input, Lambda, merge)
from keras.models import Model, Sequential
from scipy.io.wavfile import read, write
def wavenetBlock(n_atrous_filters, atrous_filter_size, atrous_rate,
                 n_conv_filters, conv_filter_size):
    """Return a closure that wires one gated WaveNet block onto a tensor.

    The block applies the WaveNet gated activation (tanh branch multiplied
    elementwise by a sigmoid gate, both dilated convolutions), a 1x1
    convolution for the skip path, and a residual sum with the block input.

    NOTE(review): n_conv_filters and conv_filter_size are accepted but never
    used — the skip convolution is hard-coded to Convolution1D(1, 1);
    confirm whether they were meant to parameterize it.

    Returns:
        A function mapping a Keras tensor to (residual_output, skip_output).
    """
    def block(layer_in):
        # Shared dilation settings for both halves of the gate.
        gate_kwargs = dict(atrous_rate=atrous_rate, border_mode='same')
        filt = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                   activation='tanh', **gate_kwargs)(layer_in)
        gate = AtrousConvolution1D(n_atrous_filters, atrous_filter_size,
                                   activation='sigmoid',
                                   **gate_kwargs)(layer_in)
        # Gated activation unit: z = tanh(Wf*x) . sigmoid(Wg*x)
        gated = merge([filt, gate], mode='mul')
        # 1x1 conv feeds the skip path; adding the input back forms the
        # residual path ('same' padding keeps shapes compatible for the sum).
        skip = Convolution1D(1, 1, activation='relu',
                             border_mode='same')(gated)
        residual_out = merge([skip, layer_in], mode='sum')
        return residual_out, skip
    return block
def get_basic_generative_model(input_size):
    """Build and compile the stacked-dilation WaveNet regression model.

    Args:
        input_size: number of audio samples in one input frame.

    Returns:
        A compiled Keras Model mapping a (1, input_size, 1) frame to a
        single tanh-activated sample prediction, with MSE loss and rmsprop.
    """
    net_in = Input(shape=(1, input_size, 1))
    # First block is wider (10 filters, width 5); the remaining four are
    # identical except that the dilation rate doubles at each level.
    residual, skip = wavenetBlock(10, 5, 2, 1, 3)(net_in)
    skips = [skip]
    for rate in (4, 8, 16, 32):
        residual, skip = wavenetBlock(1, 2, rate, 1, 3)(residual)
        skips.append(skip)
    # Sum all skip connections, as in the WaveNet output stack.
    summed = merge(skips, mode='sum')
    activated = Lambda(relu)(summed)
    conv_a = Convolution2D(1, 1, 1, activation='relu')(activated)
    conv_b = Convolution2D(1, 1, 1)(conv_a)
    flat = Flatten()(conv_b)
    net_out = Dense(1, activation='tanh')(flat)
    model = Model(input=net_in, output=net_out)
    model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
    model.summary()
    return model
def get_audio(filename):
    """Load a wav file and min-max normalize its samples into [-1.0, 1.0].

    Args:
        filename: path to a wav file readable by scipy.io.wavfile.read.

    Returns:
        (sample_rate, audio) where audio is a float ndarray scaled so the
        minimum sample maps to -1.0 and the maximum to +1.0.
    """
    sr, audio = read(filename)
    audio = audio.astype(float)
    span = audio.max() - audio.min()
    # Robustness fix: a constant (e.g. silent) signal has zero range; the
    # original divided by zero and returned NaN/inf. Return zeros instead.
    if span == 0:
        return sr, np.zeros_like(audio)
    audio = (audio - audio.min()) / span  # -> [0, 1]
    audio = (audio - 0.5) * 2             # -> [-1, 1]
    return sr, audio
def frame_generator(sr, audio, frame_size, frame_shift):
    """Endlessly yield (frame, next_sample) pairs for Keras fit_generator.

    Slides a window of frame_size samples over audio in steps of
    frame_shift; each yield pairs the window (shaped (1, 1, frame_size, 1))
    with the sample immediately after it (shaped (1, 1)). Loops over the
    audio forever, as fit_generator expects.

    Note: sr is accepted for call-site symmetry but not used here.
    """
    n_samples = len(audio)
    while True:
        start = 0
        # Stop early enough that a full window plus its target sample fit.
        while start < n_samples - frame_size - 1:
            stop = start + frame_size
            window = audio[start:stop]
            # Defensive guards kept from the original (unreachable given
            # the loop bound, but preserved for identical behavior).
            if len(window) < frame_size:
                break
            if stop >= n_samples:
                break
            target = audio[stop]
            yield (window.reshape(1, 1, frame_size, 1),
                   target.reshape(1, 1))
            start += frame_shift
if __name__ == '__main__': | |
n_epochs = 20 | |
frame_size = 2048 | |
frame_shift = 512 | |
sr_training, training_audio = get_audio('train.wav') | |
training_audio = training_audio[:sr_training*240] | |
sr_valid, valid_audio = get_audio('validate.wav') | |
valid_audio = valid_audio[:sr_valid*30] | |
assert sr_training == sr_valid, "Training, validation samplerate mismatch" | |
n_training_examples = int((len(training_audio)-frame_size-1) / float( | |
frame_shift)) | |
n_validation_examples = int((len(valid_audio)-frame_size-1) / float( | |
frame_shift)) | |
model = get_basic_generative_model(frame_size) | |
print 'Total training examples:', n_training_examples | |
print 'Total validation examples:', n_validation_examples | |
model.fit_generator(frame_generator(sr_training, training_audio, | |
frame_size, frame_shift), | |
samples_per_epoch=n_training_examples, | |
nb_epoch=n_epochs, | |
validation_data=frame_generator(sr_valid, valid_audio, | |
frame_size, frame_shift | |
), | |
nb_val_samples=n_validation_examples, | |
verbose=1) | |
print 'Saving model...' | |
str_timestamp = str(int(time.time())) | |
model.save('models/model_'+str_timestamp+'_'+str(n_epochs)+'.h5') | |
print 'Generating audio...' | |
new_audio = np.zeros((sr_training * 3)) | |
curr_sample_idx = 0 | |
audio_context = valid_audio[:frame_size] | |
while curr_sample_idx < new_audio.shape[0]: | |
predicted_val = model.predict(audio_context.reshape(1, 1, frame_size, | |
1)) | |
ampl_val_16 = predicted_val * 2**15 | |
new_audio[curr_sample_idx] = ampl_val_16 | |
audio_context[-1] = ampl_val_16 | |
audio_context[:-1] = audio_context[1:] | |
pc_str = str(round(100*curr_sample_idx/float(new_audio.shape[0]), 2)) | |
sys.stdout.write('Percent complete: ' + pc_str + '\r') | |
sys.stdout.flush() | |
curr_sample_idx += 1 | |
outfilepath = 'output/reg_generated_'+str_timestamp+'.wav' | |
print 'Writing generated audio to:', outfilepath | |
write(outfilepath, sr_training, new_audio.astype(np.int16)) | |
print '\nDone!' | |