from collections import defaultdict
from contextlib import contextmanager

import errno
import functools
import hashlib
import inspect
import io
import math
import os
import random
import socket
import tempfile
import typing as tp
import warnings
import zlib

from diffq import UniformQuantizer, DiffQuantizer
import torch as th
import tqdm
from torch import distributed
from torch.nn import functional as F
|
|
|
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
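
# Illustrative example (a sketch): unfolding a [2, 10] tensor with
# kernel_size=4 and stride=2 gives ceil(10 / 2) = 5 overlapping frames:
#
#   x = th.arange(20.).view(2, 10)
#   frames = unfold(x, kernel_size=4, stride=2)
#   assert list(frames.shape) == [2, 5, 4]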
|
|
|
|
|
def center_trim(tensor: th.Tensor, reference: tp.Union[th.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference is odd, the extra sample is removed on the right side.
    """
    ref_size: int
    if isinstance(reference, th.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {delta}.")
    if delta:
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor
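
# Illustrative example: trimming a length-10 signal down to 7 samples removes
# one sample on the left and two on the right:
#
#   y = center_trim(th.arange(10.), 7)   # -> tensor([1., 2., 3., 4., 5., 6., 7.])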
|
|
|
|
|
def pull_metric(history: tp.List[dict], name: str):
    """
    Pull the metric `name` out of every entry in `history`, following dots
    in `name` as nested dict accesses.
    """
    out = []
    for metrics in history:
        metric = metrics
        for part in name.split("."):
            metric = metric[part]
        out.append(metric)
    return out
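
# Illustrative example with a nested metric name:
#
#   history = [{"valid": {"loss": 0.5}}, {"valid": {"loss": 0.4}}]
#   pull_metric(history, "valid.loss")   # -> [0.5, 0.4]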
|
|
|
|
|
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatedly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    fix: tp.Dict[str, float] = defaultdict(float)
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal total, fix
        for key, value in metrics.items():
            total[key] = total[key] * beta + weight * float(value)
            fix[key] = fix[key] * beta + weight
        return {key: tot / fix[key] for key, tot in total.items()}
    return _update
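
# Illustrative usage: beta < 1 discounts older updates, beta = 1 averages all.
#
#   ema = EMA(beta=0.9)
#   ema({"loss": 1.0})   # -> {'loss': 1.0}
#   ema({"loss": 0.0})   # -> {'loss': 0.9 / 1.9} ~= {'loss': 0.47}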
|
|
|
|
|
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, 'Yi', suffix)
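
# Illustrative example:
#
#   sizeof_fmt(1536)   # -> '1.5KiB'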
|
|
|
|
|
@contextmanager
def temp_filenames(count: int, delete=True):
    """
    Yield `count` temporary file names; the files are removed on exit
    unless `delete` is False.
    """
    names = []
    try:
        for _ in range(count):
            names.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
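
# Illustrative usage:
#
#   with temp_filenames(2) as (first, second):
#       pass   # write scratch data to `first` and `second`; both removed on exit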
|
|
|
def average_metric(metric, count=1.):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).
    """
    metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
    return metric[1].item() / metric[0].item()
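
# Hedged usage sketch: assumes torch.distributed is initialized and CUDA is
# available on every host; the names below are illustrative only.
#
#   mean_loss = average_metric(local_loss, count=num_local_examples)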
|
|
|
|
|
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.
    This could suffer from a race condition although
    it should be quite rare.
    """
    sock = socket.socket()
    while True:
        port = random.randint(low, high)
        try:
            sock.bind((host, port))
        except OSError as error:
            if error.errno == errno.EADDRINUSE:
                continue
            raise
        # Release the socket so the caller can actually bind the port.
        sock.close()
        return port
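
# Illustrative usage, e.g. building a torch.distributed rendezvous address:
#
#   init_method = f"tcp://localhost:{free_port()}"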
|
|
def human_seconds(seconds, display='.2f'):
    """
    Given `seconds` seconds, return human readable duration.
    """
    value = seconds * 1e6
    ratios = [1e3, 1e3, 60, 60, 24]
    names = ['us', 'ms', 's', 'min', 'hrs', 'days']
    last = names.pop(0)
    for name, ratio in zip(names, ratios):
        if value / ratio < 0.3:
            break
        value /= ratio
        last = name
    return f"{format(value, display)} {last}"
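
# Illustrative examples:
#
#   human_seconds(90)      # -> '1.50 min'
#   human_seconds(0.005)   # -> '5.00 ms'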
|
|
|
|
|
class TensorChunk:
    """
    Lazy view over a slice `[offset, offset + length)` of the last dimension
    of `tensor`; no data is copied until `padded` is called.
    """
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out
|
|
|
|
|
def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)
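
# Illustrative example: a TensorChunk only materializes data when padded.
#
#   chunk = TensorChunk(th.ones(2, 100), offset=10, length=50)
#   chunk.shape              # -> [2, 50]
#   chunk.padded(60).shape   # -> torch.Size([2, 60]), centered zero-padding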
|
|
|
|
|
def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 10 second chunks
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
    """

    channels, length = mix.size()
    device = mix.device
    progress_value = 0

    if split:
        out = th.zeros(4, channels, length, device=device)
        shift = model.samplerate * 10
        offsets = range(0, length, shift)
        scale = 10
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        for offset in offsets:
            chunk = mix[..., offset:offset + shift]
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v1(model, chunk, shifts=shifts,
                                           set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v1(model, chunk, shifts=shifts)
            out[..., offset:offset + shift] = chunk_out
        return out
    elif shifts:
        max_shift = int(model.samplerate / 2)
        mix = F.pad(mix, (max_shift, max_shift))
        offsets = list(range(max_shift))
        random.shuffle(offsets)
        out = 0
        for offset in offsets[:shifts]:
            shifted = mix[..., offset:offset + length + max_shift]
            if set_progress_bar:
                shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v1(model, shifted)
            out += shifted_out[..., max_shift - offset:max_shift - offset + length]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        delta = valid_length - length
        padded = F.pad(mix, (delta // 2, delta - delta // 2))
        with th.no_grad():
            out = model(padded.unsqueeze(0))[0]
        return center_trim(out, mix)
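
# Hedged usage sketch (assumes a Demucs v1 style `model` exposing `samplerate`
# and `valid_length`, and a stereo mixture of shape [2, samples]):
#
#   sources = apply_model_v1(model, mix, shifts=1, split=True)   # -> [4, 2, samples]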
|
|
|
def apply_model_v2(model, mix, shifts=None, split=False,
                   overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into chunks of
            `model.segment_length` samples and predictions will be performed
            individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
    """

    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    progress_value = 0

    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # Triangle shaped weight, maximal in the middle of the segment, then
        # normalized and raised to `transition_power`: larger values of
        # `transition_power` give sharper transitions between chunks.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        weight = (weight / weight.max())**transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v2(model, chunk, shifts=shifts,
                                           set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v2(model, chunk, shifts=shifts)
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset:offset + segment] += weight[:chunk_length]
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)

            if set_progress_bar:
                progress_value += 1
                shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v2(model, shifted)
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
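
# Hedged usage sketch (assumes a Demucs v2 style `model` exposing `sources`,
# `samplerate`, `segment_length` and `valid_length`):
#
#   sources = apply_model_v2(model, mix, split=True, overlap=0.25)
#   # -> tensor of shape [len(model.sources), channels, samples]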
|
|
def get_quantizer(model, args, optimizer=None):
    """
    Return the quantizer selected by the training args, or None when
    quantization is disabled.
    """
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(
            model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        quantizer = UniformQuantizer(
            model, bits=args.qat, min_size=args.q_min_size)
    return quantizer
|
|
|
|
|
def load_model(path, strict=False):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        load_from = path
        package = th.load(load_from, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping nonexistent parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    training_args = package["training_args"]
    quantizer = get_quantizer(model, training_args)

    set_state(model, quantizer, state)
    return model
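
# Hedged usage sketch ("demucs.th" stands in for any package written by
# `save_model` below):
#
#   model = load_model("demucs.th")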
|
|
|
|
|
def get_state(model, quantizer):
    if quantizer is None:
        state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
    else:
        state = quantizer.get_quantized_state()
        buf = io.BytesIO()
        th.save(state, buf)
        state = {'compressed': zlib.compress(buf.getvalue())}
    return state
|
|
|
|
|
def set_state(model, quantizer, state):
    if quantizer is None:
        model.load_state_dict(state)
    else:
        buf = io.BytesIO(zlib.decompress(state["compressed"]))
        state = th.load(buf, "cpu")
        quantizer.restore_quantized_state(state)

    return state
|
|
|
|
|
def save_state(state, path):
    buf = io.BytesIO()
    th.save(state, buf)
    sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]

    # `path` is expected to be a pathlib.Path; a short content hash is inserted
    # before the suffix, e.g. `model.th` becomes `model-1a2b3c4d.th`.
    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(buf.getvalue())
|
|
|
|
|
def save_model(model, quantizer, training_args, path):
    args, kwargs = model._init_args_kwargs
    klass = model.__class__

    state = get_state(model, quantizer)

    save_to = path
    package = {
        'klass': klass,
        'args': args,
        'kwargs': kwargs,
        'state': state,
        'training_args': training_args,
    }
    th.save(package, save_to)
|
|
|
|
|
def capture_init(init):
    """
    Decorate `__init__` with this, and you can then recover the *args and
    **kwargs passed to it in `self._init_args_kwargs`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
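
# Illustrative example: `save_model` above relies on `_init_args_kwargs`.
#
#   class Net(th.nn.Module):
#       @capture_init
#       def __init__(self, depth=4):
#           super().__init__()
#           self.depth = depth
#
#   Net(depth=6)._init_args_kwargs   # -> ((), {'depth': 6})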
|
|
|
class DummyPoolExecutor:
    """
    Single-process stand-in for a concurrent.futures executor: `submit`
    defers nothing, and `result()` simply runs the function synchronously.
    """
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
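
# Illustrative usage:
#
#   with DummyPoolExecutor() as pool:
#       pool.submit(sum, [1, 2, 3]).result()   # -> 6, executed serially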
|
|