|
|
|
|
|
""" |
|
This implementation is modified from https://github.com/zcaceres/spec_augment |
|
|
|
MIT License |
|
|
|
Copyright (c) 2019 Zach Caceres |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
of this software and associated documentation files (the "Software"), to deal |
|
in the Software without restriction, including without limitation the rights |
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
copies of the Software, and to permit persons to whom the Software is |
|
furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all |
|
copies or substantial portions of the Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
SOFTWARE. |
|
""" |
|
|
|
import random |
|
|
|
import torch |
|
|
|
|
|
def specaug(
    spec, W=5, F=30, T=40, num_freq_masks=2, num_time_masks=2, replace_with_zero=False
):
    """Apply the SpecAugment policy: time warp, then frequency masks, then time masks.

    Reference:
        SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
        (https://arxiv.org/pdf/1904.08779.pdf)

    This implementation modified from https://github.com/zcaceres/spec_augment

    :param torch.Tensor spec: input tensor with the shape (T, dim)
    :param int W: time warp parameter, forwarded to ``time_warp``
    :param int F: maximum width of each freq mask, forwarded to ``freq_mask``
    :param int T: maximum width of each time mask, forwarded to ``time_mask``
    :param int num_freq_masks: number of frequency masks
    :param int num_time_masks: number of time masks
    :param bool replace_with_zero: if True, masked parts will be filled with 0,
        if False, filled with mean
    :return: the tensor returned by ``time_mask`` after all three stages
    """
    # Stages run in a fixed order: warp first, then frequency masks, then time masks.
    warped = time_warp(spec, W=W)
    freq_masked = freq_mask(
        warped, F=F, num_masks=num_freq_masks, replace_with_zero=replace_with_zero
    )
    return time_mask(
        freq_masked, T=T, num_masks=num_time_masks, replace_with_zero=replace_with_zero
    )
|
|
|
|
|
def time_warp(spec, W=5):
    """Time warping

    Picks a random time step ``t0`` in ``[W, T - W)`` on the horizontal line
    through the middle of the frequency axis and moves it by a random offset
    drawn from ``[-W, W)``, warping the spectrogram with ``sparse_image_warp``.

    :param torch.Tensor spec: input tensor with shape (T, dim)
    :param int W: time warp parameter
    :return: warped tensor with shape (T, dim). ``spec`` itself is returned
        unchanged when ``W <= 0`` or when the input is too short
        (``T <= 2 * W``) to choose a warp point.
    """
    spec_len = spec.shape[0]
    if W <= 0 or spec_len <= 2 * W:
        # random.randrange below would raise on an empty range; skip warping.
        return spec

    batched = spec.unsqueeze(0)  # (1, T, dim): a one-image batch
    num_rows = batched.shape[2]
    device = batched.device

    y = num_rows // 2
    # The warp point must be a time *index*. An earlier version indexed the
    # spectrogram at this step and used the resulting magnitude as the
    # coordinate, which put the control point at an arbitrary location.
    point_to_warp = random.randrange(W, spec_len - W)
    dist_to_warp = random.randrange(-W, W)

    # Control points are (time, freq) = (row, col), the same order used by
    # get_flat_grid_locations. Float coordinates keep the spline solve in
    # floating point.
    src_pts = torch.tensor([[[float(point_to_warp), float(y)]]], device=device)
    dest_pts = torch.tensor(
        [[[float(point_to_warp + dist_to_warp), float(y)]]], device=device
    )
    warped_spectro, _ = sparse_image_warp(batched, src_pts, dest_pts)
    return warped_spectro.squeeze(3).squeeze(0)
|
|
|
|
|
def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
    """Frequency masking

    :param torch.Tensor spec: input tensor with shape (T, dim)
    :param int F: maximum width of each mask (clamped to ``dim``)
    :param int num_masks: number of masks
    :param bool replace_with_zero: if True, masked parts will be filled with 0,
        if False, filled with mean
    :return: masked copy of ``spec`` with shape (T, dim); ``spec`` itself is
        not modified
    """
    cloned = spec.unsqueeze(0).clone()
    num_mel_channels = cloned.shape[2]

    # Clamp so that random.randrange(0, num_mel_channels - f) below always
    # has a non-empty range, even when F exceeds the number of channels.
    max_width = min(F, num_mel_channels)
    if max_width <= 0:
        return cloned.squeeze(0)

    for _ in range(num_masks):
        f = random.randrange(0, max_width)
        f_zero = random.randrange(0, num_mel_channels - f)

        # A zero-width draw would make the randrange below empty. Skip just
        # this mask; an earlier version returned here and silently dropped
        # every remaining mask.
        if f == 0:
            continue

        # NOTE: the end is drawn again from [f_zero, f_zero + f), so the
        # effective mask width lies in [0, f - 1] rather than being exactly f.
        mask_end = random.randrange(f_zero, f_zero + f)
        if replace_with_zero:
            cloned[0][:, f_zero:mask_end] = 0
        else:
            # Mean of the current tensor, including any masks already applied.
            cloned[0][:, f_zero:mask_end] = cloned.mean()
    return cloned.squeeze(0)
|
|
|
|
|
def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
    """Time masking

    :param torch.Tensor spec: input tensor with shape (T, dim)
    :param int T: maximum width of each mask (clamped to the number of frames)
    :param int num_masks: number of masks
    :param bool replace_with_zero: if True, masked parts will be filled with 0,
        if False, filled with mean
    :return: masked copy of ``spec`` with shape (T, dim); ``spec`` itself is
        not modified
    """
    cloned = spec.unsqueeze(0).clone()
    len_spectro = cloned.shape[1]

    # Clamp so that random.randrange(0, len_spectro - t) below always has a
    # non-empty range, even for utterances shorter than T frames.
    max_width = min(T, len_spectro)
    if max_width <= 0:
        return cloned.squeeze(0)

    for _ in range(num_masks):
        t = random.randrange(0, max_width)
        t_zero = random.randrange(0, len_spectro - t)

        # A zero-width draw would make the randrange below empty. Skip just
        # this mask; an earlier version returned here and silently dropped
        # every remaining mask.
        if t == 0:
            continue

        # NOTE: the end is drawn again from [t_zero, t_zero + t), so the
        # effective mask width lies in [0, t - 1] rather than being exactly t.
        mask_end = random.randrange(t_zero, t_zero + t)
        if replace_with_zero:
            cloned[0][t_zero:mask_end, :] = 0
        else:
            # Mean of the current tensor, including any masks already applied.
            cloned[0][t_zero:mask_end, :] = cloned.mean()
    return cloned.squeeze(0)
|
|
|
|
|
def sparse_image_warp(
    img_tensor,
    source_control_point_locations,
    dest_control_point_locations,
    interpolation_order=2,
    regularization_weight=0.0,
    num_boundaries_points=0,
):
    """Warp an image batch by moving a sparse set of control points.

    A polyharmonic spline is fitted to the displacements ``dest - source``,
    using the destination points as centres. It is evaluated at every pixel
    to get a dense flow field, and the image is resampled with
    ``dense_image_warp``.

    :param torch.Tensor img_tensor: image batch of shape (batch, height, width)
    :param torch.Tensor source_control_point_locations: control points before
        warping, in the same (row, col) order as get_flat_grid_locations
    :param torch.Tensor dest_control_point_locations: control points after
        warping, same shape as the source points
    :param int interpolation_order: spline order forwarded to interpolate_spline
    :param float regularization_weight: forwarded to interpolate_spline
    :param int num_boundaries_points: accepted for interface compatibility;
        not used by this implementation
    :return: tuple ``(warped_image, dense_flows)``; ``dense_flows`` has shape
        (batch, height, width, 2)
    """
    flows_at_controls = dest_control_point_locations - source_control_point_locations
    batch_size, image_height, image_width = img_tensor.shape

    pixel_coords = get_flat_grid_locations(
        image_height, image_width, img_tensor.device
    )
    flat_flows = interpolate_spline(
        dest_control_point_locations,
        flows_at_controls,
        pixel_coords,
        interpolation_order,
        regularization_weight,
    )
    dense_flows = create_dense_flows(flat_flows, batch_size, image_height, image_width)

    warped = dense_image_warp(img_tensor, dense_flows)
    return warped, dense_flows
|
|
|
|
|
def get_grid_locations(image_height, image_width, device):
    """Return the (row, col) coordinate of every pixel.

    :param int image_height: number of rows
    :param int image_width: number of columns
    :param device: torch device for the result
    :return: float tensor of shape (image_height, image_width, 2) with
        ``out[i, j] == (i, j)``
    """
    rows = torch.linspace(0, image_height - 1, image_height, device=device)
    cols = torch.linspace(0, image_width - 1, image_width, device=device)
    # Broadcasting reproduces torch.meshgrid's "ij" layout without the
    # indexing-argument deprecation warning.
    row_grid = rows.view(-1, 1).expand(image_height, image_width)
    col_grid = cols.view(1, -1).expand(image_height, image_width)
    return torch.stack((row_grid, col_grid), -1)
|
|
|
|
|
def flatten_grid_locations(grid_locations, image_height, image_width):
    """Collapse a (height, width, 2) coordinate grid into (height * width, 2)."""
    num_pixels = image_height * image_width
    return grid_locations.reshape(num_pixels, 2)
|
|
|
|
|
def get_flat_grid_locations(image_height, image_width, device):
    """Return the (row, col) coordinate of every pixel as a flat list.

    :param int image_height: number of rows
    :param int image_width: number of columns
    :param device: torch device for the result
    :return: float tensor of shape (image_height * image_width, 2), in
        row-major pixel order
    """
    rows = torch.linspace(0, image_height - 1, image_height, device=device)
    cols = torch.linspace(0, image_width - 1, image_width, device=device)
    # Equivalent to torch.meshgrid(rows, cols) with "ij" indexing.
    row_grid = rows.view(-1, 1).expand(image_height, image_width)
    col_grid = cols.view(1, -1).expand(image_height, image_width)
    grid = torch.stack((row_grid, col_grid), -1)
    return grid.reshape([image_height * image_width, 2])
|
|
|
|
|
def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
    """Reshape per-pixel flows into a (batch, height, width, 2) field."""
    target_shape = (batch_size, image_height, image_width, 2)
    return flattened_flows.reshape(target_shape)
|
|
|
|
|
def interpolate_spline(
    train_points,
    train_values,
    query_points,
    order,
    regularization_weight=0.0,
):
    """Fit a polyharmonic spline and evaluate it at ``query_points``.

    :param torch.Tensor train_points: interpolation centres
    :param torch.Tensor train_values: values observed at the centres
    :param torch.Tensor query_points: points at which to evaluate the spline
    :param int order: polyharmonic order passed to ``phi``
    :param float regularization_weight: forwarded to ``solve_interpolation``
    :return: spline values at ``query_points`` (see ``apply_interpolation``)
    """
    rbf_weights, linear_weights = solve_interpolation(
        train_points, train_values, order, regularization_weight
    )
    return apply_interpolation(
        query_points, train_points, rbf_weights, linear_weights, order
    )
|
|
|
|
|
def solve_interpolation(train_points, train_values, order, regularization_weight):
    """Solve for polyharmonic spline coefficients.

    Builds and solves the block linear system

        [[A + lambda * I, B], [B^T, ~0]] @ [w; v] = [f; 0]

    where ``A = phi(pairwise squared distances between the centres)``, ``B``
    is the centres padded with a column of ones, ``f`` are the training
    values and ``lambda`` is ``regularization_weight``.

    NOTE(review): only batch size 1 is effectively supported, because
    ``cross_squared_distance_matrix`` drops the leading batch dim and it is
    re-added here with ``unsqueeze(0)``.

    :param torch.Tensor train_points: [b, n, d] interpolation centres
    :param torch.Tensor train_values: [b, n, k] values at the centres
    :param int order: interpolation order passed to ``phi``
    :param float regularization_weight: if > 0, added to the diagonal of A
        (smoothing instead of exact interpolation)
    :return: tuple ``(w, v)`` with shapes [b, n, k] and [b, d + 1, k]
    """
    device = train_points.device
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)
    if regularization_weight > 0:
        identity = torch.eye(n, dtype=matrix_a.dtype, device=device).unsqueeze(0)
        matrix_a = matrix_a + regularization_weight * identity

    # One trailing 1 per (batch, centre). A fixed (1, 1, 1) tensor would
    # only concatenate when there is exactly one control point.
    ones = torch.ones_like(c[..., :1])
    matrix_b = torch.cat((c, ones), 2).float()

    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]
    # The reference (TensorFlow) formulation uses zeros for this block. Tiny
    # random values are used instead, presumably so the solve does not fail
    # on a singular system: with a single control point the zero-block
    # system is rank-deficient.
    # NOTE(review): this consumes the global torch RNG, and in that
    # degenerate case the affine part of the solution depends on the draw.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros), 1)
    lhs = torch.cat((left_block, right_block), 2)

    rhs_zeros = torch.zeros(
        (b, d + 1, k), dtype=train_points.dtype, device=device
    ).float()
    rhs = torch.cat((f, rhs_zeros), 1)

    # torch.gesv was removed from PyTorch. Prefer torch.linalg.solve(A, B)
    # and fall back to the older (B, A)-ordered APIs on old releases.
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        X = linalg.solve(lhs, rhs)
    elif hasattr(torch, "solve"):
        X, _ = torch.solve(rhs, lhs)
    else:
        X, _ = torch.gesv(rhs, lhs)

    w = X[:, :n, :]
    v = X[:, n:, :]
    return w, v
|
|
|
|
|
def cross_squared_distance_matrix(x, y):
    """Pairwise squared distance between two (batch) matrices' rows (2nd dim).

    Computes the pairwise squared Euclidean distances between rows of x and
    rows of y.

    Args:
        x: [1, n, d] float `Tensor` (leading batch dim of size 1)
        y: [1, m, d] float `Tensor` (leading batch dim of size 1)

    Returns:
        squared_dists: [n, m] float32 `Tensor`, where
        squared_dists[i, j] = ||x[0, i, :] - y[0, j, :]||^2.
        The size-1 batch dimension is dropped; callers re-add it where needed.
    """
    x_rows = x.squeeze(0)
    y_rows = y.squeeze(0)

    # Per-row squared norms: (n, 1) and (m,), broadcast to (n, m) below.
    # Summing over the whole tensor would add the same total norm to every
    # entry instead of each row's own norm.
    x_norm_squared = torch.sum(torch.mul(x_rows, x_rows), dim=-1, keepdim=True)
    y_norm_squared = torch.sum(torch.mul(y_rows, y_rows), dim=-1)

    x_y_transpose = torch.matmul(x_rows, y_rows.transpose(0, 1))

    # ||a - b||^2 = a.a - 2 a.b + b.b
    squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared

    return squared_dists.float()
|
|
|
|
|
def phi(r, order):
    """Coordinate-wise nonlinearity used to define the order of the interpolation.

    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

    Args:
        r: input tensor; in this file it receives the squared distances
            produced by cross_squared_distance_matrix
        order: interpolation order

    Returns:
        phi_k evaluated coordinate-wise on r, for k = order. r is clamped to
        at least 1e-10 before any sqrt/log/pow. For orders 2 and 4 the
        unclamped r is used as the polynomial factor.
    """
    eps = torch.tensor(1e-10, device=r.device)

    if order == 1:
        return torch.sqrt(torch.max(r, eps))
    if order == 2:
        return 0.5 * r * torch.log(torch.max(r, eps))
    if order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, eps))

    # General orders: clamp once, then r^(k/2), times log(r)/2 for even k.
    clamped = torch.max(r, eps)
    power = torch.pow(clamped, 0.5 * order)
    if order % 2 == 0:
        return 0.5 * power * torch.log(clamped)
    return power
|
|
|
|
|
def apply_interpolation(query_points, train_points, w, v, order):
    """Apply polyharmonic interpolation model to data.

    Notes:
        Given coefficients w and v for the interpolation model, we evaluate
        interpolated function values at query_points.

    Args:
        query_points: `[m, d]` x values to evaluate the interpolation at.
            A batch dim is added here via unsqueeze(0), so no batch dim is
            expected (e.g. the output of get_flat_grid_locations).
        train_points: `[1, n, d]` x values that act as the interpolation
            centres (the c variables in the wikipedia article)
        w: `[b, n, k]` weights on each interpolation center
        v: `[b, d + 1, k]` weights on each input dimension plus a bias row
        order: order of the interpolation

    Returns:
        Polyharmonic interpolation evaluated at query_points, shape
        `[b, m, k]`.
    """
    batched_queries = query_points.unsqueeze(0)

    # Radial-basis part: phi(distance to each centre) weighted by w.
    dists = cross_squared_distance_matrix(
        batched_queries.float(), train_points.float()
    )
    rbf_term = torch.matmul(phi(dists, order), w)

    # Affine part: append a 1 to every query point so v's last row is a bias.
    bias_column = torch.ones_like(batched_queries[..., :1])
    padded_queries = torch.cat((batched_queries, bias_column), 2).float()
    linear_term = torch.matmul(padded_queries, v)

    return rbf_term + linear_term
|
|
|
|
|
def dense_image_warp(image, flow):
    """Image warping using per-pixel flow vectors.

    Apply a non-linear warp to the image, where the warp is specified by a dense
    flow field of offset vectors that define the correspondences of pixel values
    in the output image back to locations in the source image. Specifically, the
    pixel value at output[b, j, i, 0] is
    image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]].
    The locations specified by this formula do not necessarily map to an int
    index. Therefore, the pixel value is obtained by bilinear
    interpolation of the 4 nearest pixels around
    (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
    of the image, the nearest pixel values at the image boundary are used.

    Args:
        image: 3-D float `Tensor` with shape `[batch, height, width]`
            (single channel; a channel axis is added internally).
        flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.

    Returns:
        A 4-D float `Tensor` with shape `[batch, height, width, 1]`.

    Raises:
        ValueError: if height < 2 or width < 2 or the inputs have the wrong
            number of dimensions.
    """
    if image.dim() != 3:
        raise ValueError(
            "image must be 3-D [batch, height, width]; got shape "
            + str(tuple(image.shape))
        )
    if flow.dim() != 4:
        raise ValueError(
            "flow must be 4-D [batch, height, width, 2]; got shape "
            + str(tuple(flow.shape))
        )
    if image.shape[1] < 2 or image.shape[2] < 2:
        # Bilinear interpolation needs a 2x2 neighbourhood; smaller images
        # previously failed later with negative gather indices.
        raise ValueError(
            "image height and width must both be >= 2; got shape "
            + str(tuple(image.shape))
        )

    image = image.unsqueeze(3)
    batch_size, height, width, channels = image.shape
    device = image.device

    # Pixel (row, col) coordinates as a (1, height, width, 2) grid.
    rows = torch.arange(height, device=device).view(height, 1).expand(height, width)
    cols = torch.arange(width, device=device).view(1, width).expand(height, width)
    batched_grid = torch.stack((rows, cols), dim=2).float().unsqueeze(0)

    query_points_on_grid = batched_grid - flow
    query_points_flattened = torch.reshape(
        query_points_on_grid, [batch_size, height * width, 2]
    )

    interpolated = interpolate_bilinear(image, query_points_flattened)
    interpolated = torch.reshape(interpolated, [batch_size, height, width, channels])
    return interpolated
|
|
|
|
|
def interpolate_bilinear(
    grid, query_points, name="interpolate_bilinear", indexing="ij"
):
    """Similar to Matlab's interp2 function.

    Notes:
        Finds values for query points on a grid using bilinear interpolation.
        Query points outside the grid take the value at the nearest border.

    Args:
        grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
        query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
        name: accepted for compatibility with the TensorFlow signature; unused.
        indexing: whether the query points are specified as row and column (ij),
            or Cartesian coordinates (xy).

    Returns:
        values: a 3-D `Tensor` with shape `[batch, N, channels]`, same dtype
        as grid.

    Raises:
        ValueError: if the indexing mode is invalid, or if the shape of the
            inputs is invalid.
    """
    if indexing != "ij" and indexing != "xy":
        raise ValueError("Indexing mode must be 'ij' or 'xy'")

    shape = grid.shape
    if len(shape) != 4:
        msg = "Grid must be 4 dimensional. Received size: "
        raise ValueError(msg + str(grid.shape))

    batch_size, height, width, channels = grid.shape
    if height < 2 or width < 2:
        raise ValueError("Grid must be at least batch_size x 2 x 2 in size.")
    if query_points.dim() != 3 or query_points.shape[2] != 2:
        raise ValueError(
            "Query points must be 3 dimensional [batch, N, 2]. Received size: "
            + str(query_points.shape)
        )

    query_type = query_points.dtype
    grid_type = grid.dtype
    grid_device = grid.device
    num_queries = query_points.shape[1]

    alphas = []
    floors = []
    ceils = []
    index_order = [0, 1] if indexing == "ij" else [1, 0]
    unstacked_query_points = query_points.unbind(2)

    for i, dim in enumerate(index_order):
        queries = unstacked_query_points[dim]

        # The i-th pass indexes grid axis i + 1 (rows, then columns). With
        # "xy" indexing the query component is dim, but the grid axis is
        # still given by the pass number.
        size_in_indexing_dimension = grid.shape[i + 1]

        # Clamp the floor to [0, size - 2] so floor + 1 is always a valid
        # index; points outside the grid then snap to the border via alpha.
        max_floor = torch.tensor(
            size_in_indexing_dimension - 2, dtype=query_type, device=grid_device
        )
        min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device)
        floor = torch.min(torch.max(min_floor, torch.floor(queries)), max_floor)
        int_floor = floor.long()
        floors.append(int_floor)
        ceils.append(int_floor + 1)

        # Interpolation weight, computed in the grid's dtype and clamped to [0, 1].
        alpha = (queries - floor).to(dtype=grid_type, device=grid_device)
        min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device)
        max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device)
        alpha = torch.min(torch.max(min_alpha, alpha), max_alpha)

        # Trailing axis so alpha broadcasts over channels.
        alphas.append(torch.unsqueeze(alpha, 2))

    flattened_grid = torch.reshape(grid, [batch_size * height * width, channels])
    batch_offsets = torch.reshape(
        torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1]
    )

    def gather(y_coords, x_coords):
        # Row-index the flattened grid. This works for any batch size and
        # channel count; gathering along dim 1 of the transposed grid only
        # worked when batch == channels == 1.
        linear_coordinates = batch_offsets + y_coords * width + x_coords
        gathered_values = flattened_grid[linear_coordinates.reshape(-1)]
        return torch.reshape(gathered_values, [batch_size, num_queries, channels])

    top_left = gather(floors[0], floors[1])
    top_right = gather(floors[0], ceils[1])
    bottom_left = gather(ceils[0], floors[1])
    bottom_right = gather(ceils[0], ceils[1])

    interp_top = alphas[1] * (top_right - top_left) + top_left
    interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
    interp = alphas[0] * (interp_bottom - interp_top) + interp_top

    return interp
|
|