# Pancake_HFv1 / demucs / train.py
# Uploaded by r3gm ("Upload 288 files"), commit 7bc29af, 4.24 kB.
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .utils import apply_model, average_metric, center_trim
def _build_train_loader(dataset, batch_size, epoch, repeat, seed, workers, world_size):
    """Return ``(loader, sampler, per_process_batch_size)`` for one training call.

    Single-process (``world_size <= 1``): a shuffling ``DataLoader`` and no sampler.
    Multi-process: a ``DistributedSampler`` whose epoch is set to
    ``epoch * repeat`` (plus ``seed * 1000`` when a seed is given), and a
    loader whose batch size is ``batch_size // world_size``.
    """
    if world_size <= 1:
        shuffled = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)
        return shuffled, None, batch_size

    sampler = DistributedSampler(dataset)
    start_epoch = epoch * repeat
    if seed is not None:
        # Offset by the seed so different seeds see different shuffles.
        start_epoch += seed * 1000
    sampler.set_epoch(start_epoch)
    local_batch = batch_size // world_size
    loader = DataLoader(dataset, batch_size=local_batch, sampler=sampler, num_workers=workers)
    return loader, sampler, local_batch


def _global_grad_norm(parameters):
    """Return the L2 norm over every existing ``.grad`` in *parameters* (0.0 if none)."""
    squared = sum(p.grad.detach().norm() ** 2 for p in parameters if p.grad is not None)
    return squared ** 0.5


def train_model(epoch,
                dataset,
                model,
                criterion,
                optimizer,
                augment,
                quantizer=None,
                diffq=0,
                repeat=1,
                device="cpu",
                seed=None,
                workers=4,
                world_size=1,
                batch_size=16):
    """Train ``model`` for ``repeat`` passes over ``dataset``.

    For each full batch:

    - move it to ``device`` and pass it through ``augment``;
    - sum it over dim 1 to build the mixture fed to ``model``;
    - ``center_trim`` the targets to the estimates and score them with ``criterion``.

    When ``quantizer`` is given, ``diffq * quantizer.model_size()`` is added
    to the loss before backpropagation. Incomplete batches are skipped.

    Args:
        epoch: epoch index, used for the progress-bar label and sampler seeding.
        dataset: training dataset consumed through a ``DataLoader``.
        model, criterion, optimizer, augment: training components.
        quantizer: optional object exposing ``model_size()``.
        diffq: weight of the model-size penalty.
        repeat: number of passes over the loader.
        device: device the batches are moved to.
        seed: optional value offsetting the distributed sampler's epoch.
        workers: ``DataLoader`` worker count.
        world_size: number of processes; > 1 enables ``DistributedSampler``
            and divides ``batch_size`` among processes.
        batch_size: global batch size.

    Returns:
        ``(current_loss, model_size)``.

        - ``current_loss`` is the running mean of ``criterion`` over the
          last repetition, passed through ``average_metric`` when
          ``world_size > 1``.
        - ``model_size`` is the last penalty value (a Python float when a
          quantizer is used, else 0).
    """
    loader, sampler, batch_size = _build_train_loader(
        dataset, batch_size, epoch, repeat, seed, workers, world_size)

    current_loss = 0
    model_size = 0
    for repetition in range(repeat):
        progress = tqdm.tqdm(loader,
                             ncols=120,
                             desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
                             leave=False,
                             file=sys.stdout,
                             unit=" batch")
        running_total = 0
        for step, batch in enumerate(progress):
            # augment.Remix needs full batches, so a short final batch is skipped.
            if len(batch) < batch_size:
                continue
            batch = augment(batch.to(device))
            mix = batch.sum(dim=1)
            estimates = model(mix)
            batch = center_trim(batch, estimates)
            loss = criterion(estimates, batch)

            model_size = quantizer.model_size() if quantizer is not None else 0
            train_loss = loss + diffq * model_size
            train_loss.backward()
            # Measured after backward and before the optimizer consumes/clears grads.
            grad_norm = _global_grad_norm(model.parameters())
            optimizer.step()
            optimizer.zero_grad()

            if quantizer is not None:
                model_size = model_size.item()
            running_total += loss.item()
            current_loss = running_total / (step + 1)
            progress.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
                                 grad=f"{grad_norm:.5f}")

            # Drop this batch's tensors before the next batch is loaded.
            del batch, mix, estimates, loss, train_loss

        if sampler is not None:
            # Advance the sampler so the next repetition is shuffled differently.
            sampler.set_epoch(sampler.epoch + 1)

    if world_size > 1:
        current_loss = average_metric(current_loss)
    return current_loss, model_size
def validate_model(epoch,
                   dataset,
                   model,
                   criterion,
                   device="cpu",
                   rank=0,
                   world_size=1,
                   shifts=0,
                   overlap=0.25,
                   split=False):
    """Evaluate ``model`` on this process's share of ``dataset``; return the mean loss.

    This process handles indices ``rank, rank + world_size, ...``.

    For each item:

    - ``streams[0]`` is used as the mixture and ``streams[1:]`` as the targets;
    - estimates come from ``apply_model`` (with ``shifts``, ``split``, ``overlap``);
    - the ``criterion`` values are averaged over this process's items.

    When ``world_size > 1``, the result is passed through ``average_metric``,
    weighted by this process's item count.

    Args:
        epoch: epoch index, used only for the progress-bar label.
        dataset: indexable dataset; each item must support ``[..., :n]``,
            ``.to(device)``, ``[0]`` and ``[1:]``.
        model: model passed to ``apply_model``.
        criterion: callable ``(estimates, sources) -> loss tensor``.
        device: device each item is moved to.
        rank, world_size: this process's position and the process count.
        shifts, overlap, split: forwarded to ``apply_model``.

    Returns:
        The mean validation loss as a Python float (0 if this process has no items).
    """
    indexes = range(rank, len(dataset), world_size)
    tq = tqdm.tqdm(indexes,
                   ncols=120,
                   desc=f"[{epoch:03d}] valid",
                   leave=False,
                   file=sys.stdout,
                   unit=" track")
    current_loss = 0
    for index in tq:
        streams = dataset[index]
        # Keep at most the first 15_000_000 samples (described upstream as the
        # "first five minutes") to avoid OOM on --upsample models.
        streams = streams[..., :15_000_000]
        streams = streams.to(device)
        sources = streams[1:]
        mix = streams[0]
        estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
        loss = criterion(estimates, sources)
        current_loss += loss.item() / len(indexes)
        # `mix` and `sources` are views sharing `streams`' storage, so every one
        # of them must be released (along with `estimates` and `loss`) for the
        # track's memory to be freed before the next track is moved to `device`.
        del estimates, streams, sources, mix, loss
    if world_size > 1:
        current_loss = average_metric(current_loss, len(indexes))
    return current_loss