import glob
import math
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as tvt
from torch.utils.data import DataLoader, Dataset
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
from diffusers import Mel
from PIL import Image
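# Class-conditional audio generation on AudioMNIST: each .wav file is
# rendered as a mel-spectrogram image, a classifier-free-guidance diffusion
# model learns to generate spectrograms conditioned on the spoken digit,
# and per-epoch samples are inverted back to audio.
#
# Usage: python <this script> [image_size] [num_epochs] [batch_size]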

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Positional CLI arguments: [image_size] [num_epochs] [batch_size].
args = sys.argv[1:]

class Audio(Dataset):
    """Loads every .wav file under `folder`, labelled by the spoken digit."""

    def __init__(self, folder):
        self.waveforms = []
        self.labels = []
        print("Loading files...")
        for file in glob.iglob(folder + '/**/*.wav', recursive=True):
            # AudioMNIST files are named <digit>_<speaker>_<index>.wav, so the
            # first character of the filename is the class label.
            self.labels.append(int(os.path.basename(file)[0]))
            waveform, _ = torchaudio.load(file)
            self.waveforms.append(waveform)

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]

image_size = 256
if len(args) >= 1:
    image_size = int(args[0])

# Mel renders each audio slice as an image_size x image_size grayscale
# spectrogram. AudioMNIST is recorded at 48 kHz while Mel defaults to
# 22,050 Hz, so the rate is set explicitly to match the 48 kHz used when
# saving samples below.
MEL = Mel(x_res=image_size, y_res=image_size, sample_rate=48000)
img_to_tensor = tvt.PILToTensor()

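# The collate function below expands each waveform into one or more
# fixed-length spectrogram slices (Mel decides how many fit), so a batch of
# N files can yield more than N training images; every slice inherits the
# digit label of its source file.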
def collate(batch):
    spectros = []
    labels = []
    for waveform, label in batch:
        # Mel expects a 1-D NumPy array rather than a torch tensor.
        MEL.load_audio(raw_audio=waveform[0].numpy())
        for slice_idx in range(MEL.get_number_of_slices()):
            spectro = MEL.audio_slice_to_image(slice_idx)
            spectro = img_to_tensor(spectro) / 255.0  # scale to [0, 1]
            spectros.append(spectro)
            labels.append(label)

    spectros = torch.stack(spectros)
    labels = torch.tensor(labels)

    return spectros.to(device), labels.to(device)


def initialize(use_scheduler=False):
    model = Unet(
        dim=64,
        num_classes=10,          # one class per spoken digit
        dim_mults=(1, 2, 4, 8),
        channels=1               # grayscale spectrograms
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
    )
    diffusion.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False
        )
    return diffusion, optim, scheduler

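# CyclicLR is designed to be stepped once per batch, which is why train()
# calls scheduler.step() inside the inner loop rather than once per epoch.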
def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


start = time.time()

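# During training, the classifier_free_guidance Unet randomly drops the class
# conditioning (its cond_drop_prob argument, 0.5 by default at the time of
# writing), which is what later allows sampling to blend conditional and
# unconditional predictions.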
def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
    size = len(train_dl.dataset)
    model.train()
    losses = []
    print(f"{'Epoch':^7} | {'Step':^7} | {'Avg Loss':^12} | Elapsed | Progress")

    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            model.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)

            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()

            if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
                avg_loss = batch_loss / batch_counts
                print(f"{e + 1:^7} | {step:^7} | {avg_loss:^12.6f} | {timeSince(start)} | {step * batch_size:>5d}/{size:>5d}")
                losses.append(avg_loss)
                batch_loss, batch_counts = 0, 0

        # Sample spectrograms for random digits at the end of each epoch.
        # randint's upper bound is exclusive, so (0, 10) covers digits 0-9.
        labels = torch.randint(0, 10, (8,)).to(device)
        print("Sampling digits:", labels.tolist())
        samples = model.sample(labels)
        for i, sample in enumerate(samples):
            # sample is (1, H, W) in [0, 1]; a 2-D uint8 array becomes a
            # mode 'L' (grayscale) PIL image.
            im = Image.fromarray((sample[0].cpu().numpy() * 255).astype(np.uint8))
            audio = torch.from_numpy(MEL.image_to_audio(im)).float().unsqueeze(0)
            torchaudio.save(f"audio/sample{e}_{i}_{labels[i].item()}.wav", audio, 48000)
            im.save(f"images/sample{e}_{i}_{labels[i].item()}.jpg")
    return losses

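# Note that Mel.image_to_audio reconstructs the waveform via Griffin-Lim,
# which estimates phase from the magnitude spectrogram, so the recovered
# audio is a lossy approximation rather than an exact inverse.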
if __name__ == "__main__":
    num_epochs = 10
    if len(args) >= 2:
        num_epochs = int(args[1])

    batch_size = 32
    if len(args) >= 3:
        batch_size = int(args[2])

    print(f"image_size={image_size} num_epochs={num_epochs} batch_size={batch_size}")
    model, optim, scheduler = initialize(use_scheduler=True)
    train_data = Audio("AudioMNIST/data")
    print("Done loading")
    train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate)

    # The per-epoch samples are written into these directories.
    os.makedirs("audio", exist_ok=True)
    os.makedirs("images", exist_ok=True)

    train(model, optim, train_dl, batch_size, num_epochs, scheduler)
    torch.save(model.state_dict(), "diffusion_condition_model.pt")
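# To reuse the trained weights later, rebuild the same model and load the
# state dict (a sketch; image_size must match the training run):
#
#   diffusion, _, _ = initialize()
#   diffusion.load_state_dict(torch.load("diffusion_condition_model.pt"))
#   diffusion.eval()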