|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def searchsorted(a, v):
    """Bracket every query in v between two neighbouring indices of a.

    Every query is compared against every reference point (a dense
    comparison rather than a binary search).

    Args:
      a: tensor, the sorted reference points along the last dimension.
      v: tensor, the query points. They need not be sorted. All but the last
        dimension must match or broadcast against those of a; the last
        dimension can differ.

    Returns:
      (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi]. When no reference
      point is <= v, idx_lo falls back to the first index of a; when no
      reference point is > v, idx_hi falls back to the last index of a.
    """
    # Column vector of candidate indices, shaped to line up with axis -2.
    positions = torch.arange(a.shape[-1], device=a.device)[:, None]
    # is_ge[..., j, k] is True when query k is at or beyond reference point j.
    is_ge = a[..., :, None] <= v[..., None, :]

    # Largest index whose reference point is <= v (default: first index).
    lo_candidates = torch.where(is_ge, positions, positions[:1])
    idx_lo = lo_candidates.max(dim=-2).values

    # Smallest index whose reference point is > v (default: last index).
    hi_candidates = torch.where(is_ge, positions[-1:], positions)
    idx_hi = hi_candidates.min(dim=-2).values
    return idx_lo, idx_hi
|
|
|
|
|
def query(tq, t, y, outside_value=0):
    """Look up the values of the step function (t, y) at locations tq.

    Args:
      tq: tensor, the query locations.
      t: tensor, the sorted endpoints passed to `searchsorted` as reference
        points.
      y: tensor, the step values; gathered along the last dimension with the
        lower bracketing index of each query.
      outside_value: the value returned for queries whose two bracketing
        indices coincide (i.e. `searchsorted` found no enclosing interval).

    Returns:
      yq: tensor holding y at the lower bracketing index of each query, or
      `outside_value` where the query is not enclosed by an interval.
    """
    idx_lo, idx_hi = searchsorted(t, tq)
    # A query at or beyond the last endpoint gets idx_lo == last index of t,
    # which is out of bounds for the gather when y holds one value per
    # interval. Clamp it into range; the gathered value is discarded below
    # because idx_lo == idx_hi for exactly those queries.
    gather_idx = idx_lo.clamp_max(y.shape[-1] - 1)
    y_lo = torch.take_along_dim(y, gather_idx, dim=-1)
    # Build the fill from the gathered values rather than from the int64
    # index tensor, so a non-integer `outside_value` is not truncated.
    fill = torch.full_like(y_lo, outside_value)
    return torch.where(idx_lo == idx_hi, fill, y_lo)
|
|
|
|
|
def inner_outer(t0, t1, y1):
    """Construct inner and outer measures on (t1, y1) for t0.

    The values y1 are accumulated into a zero-prefixed running sum, which is
    then read at the bracketing indices that `searchsorted(t1, t0)` reports
    for every endpoint in t0.

    Args:
      t0: tensor, the endpoints of the intervals to measure.
      t1: tensor, the sorted endpoints of the source step function.
      y1: tensor, the values of the source step function.

    Returns:
      (y0_inner, y0_outer), one entry per consecutive pair of endpoints in t0.
      y0_outer is the running sum at the upper bracket of the right endpoint
      minus the running sum at the lower bracket of the left endpoint.
      y0_inner is the running sum at the lower bracket of the right endpoint
      minus the running sum at the upper bracket of the left endpoint, or 0
      when those brackets cross.
    """
    zero = torch.zeros_like(y1[..., :1])
    cy1 = torch.cat([zero, torch.cumsum(y1, dim=-1)], dim=-1)
    idx_lo, idx_hi = searchsorted(t1, t0)

    cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)

    y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
    inner_mass = cy1_lo[..., 1:] - cy1_hi[..., :-1]
    # The zero branch is created in the value dtype (not from the int64
    # index tensor) so torch.where never has to mix dtypes.
    y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
                           inner_mass, torch.zeros_like(inner_mass))
    return y0_inner, y0_outer
|
|
|
|
|
def lossfun_outer(t, w, t_env, w_env):
    """Penalise weights that exceed the outer envelope of the proposal.

    The proposal weight should be an upper envelope on the nerf weight, so
    only the part of w that rises above the outer measure is penalised.

    Args:
      t: tensor, the endpoints belonging to w.
      w: tensor, the weights that should stay below the envelope.
      t_env: tensor, the endpoints of the envelope (proposal) step function.
      w_env: tensor, the weights of the envelope (proposal) step function.

    Returns:
      Tensor of element-wise losses: the squared positive excess of w over
      the outer measure, divided by (w + eps).
    """
    eps = torch.finfo(t.dtype).eps
    _, envelope = inner_outer(t, t_env, w_env)
    # Only the amount by which w overshoots the envelope contributes.
    excess = torch.clamp(w - envelope, min=0)
    return excess ** 2 / (w + eps)
|
|
|
|
|
def weight_to_pdf(t, w):
    """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.

    Args:
      t: tensor, interval endpoints along the last dimension.
      w: tensor, one weight per interval.

    Returns:
      Tensor of w divided by the interval widths, where each width is
      clamped from below at the machine epsilon of t's dtype.
    """
    smallest_width = torch.finfo(t.dtype).eps
    widths = torch.diff(t, dim=-1).clamp_min(smallest_width)
    return w / widths
|
|
|
|
|
def pdf_to_weight(t, p):
    """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.

    Args:
      t: tensor, interval endpoints along the last dimension.
      p: tensor, one density value per interval.

    Returns:
      Tensor of p multiplied by the width of each interval.
    """
    widths = torch.diff(t, dim=-1)
    return p * widths
|
|
|
|
|
def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
    """Dilate (via max-pooling) a non-negative step function.

    Every interval [t[i], t[i+1]) is widened by `dilation` on both sides, and
    the dilated function takes the maximum value of all widened intervals
    that cover a point (0 where none does).

    Args:
      t: tensor, interval endpoints along the last dimension.
      w: tensor, one non-negative value per interval.
      dilation: amount by which each interval is extended in both directions.
      domain: (minval, maxval), the range the new endpoints are clipped to.

    Returns:
      (t_dilate, w_dilate): the sorted, clipped union of the original and the
      shifted endpoints, and the max-pooled value for each new interval.
    """
    starts = t[..., :-1] - dilation
    stops = t[..., 1:] + dilation
    merged = torch.cat([t, starts, stops], dim=-1)
    t_dilate = torch.sort(merged, dim=-1).values
    t_dilate = torch.clip(t_dilate, *domain)

    # covered[..., k, j]: new endpoint k lies inside widened interval j.
    points = t_dilate[..., None]
    covered = (starts[..., None, :] <= points) & (points < stops[..., None, :])
    values = w[..., None, :]
    pooled = torch.where(covered, values, torch.zeros_like(values))
    # One value per new interval, so the entry for the last endpoint is dropped.
    w_dilate = pooled.max(dim=-1).values[..., :-1]
    return t_dilate, w_dilate
|
|
|
|
|
def max_dilate_weights(t,
                       w,
                       dilation,
                       domain=(-torch.inf, torch.inf),
                       renormalize=False):
    """Dilate (via max-pooling) a set of weights.

    The weights are converted to a PDF, the PDF is max-dilated, and the
    result is converted back to weights on the dilated endpoints.

    Args:
      t: tensor, interval endpoints along the last dimension.
      w: tensor, one weight per interval.
      dilation: amount by which each interval is extended in both directions.
      domain: (minval, maxval), the range the new endpoints are clipped to.
      renormalize: bool, if True, divide the dilated weights by their sum
        along the last dimension (the sum is clamped from below at eps).

    Returns:
      (t_dilate, w_dilate): the dilated endpoints and weights.
    """
    eps = torch.finfo(w.dtype).eps

    density = weight_to_pdf(t, w)
    t_dilate, density_dilate = max_dilate(t, density, dilation, domain=domain)
    w_dilate = pdf_to_weight(t_dilate, density_dilate)
    if not renormalize:
        return t_dilate, w_dilate

    total = torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
    return t_dilate, w_dilate / total
|
|
|
|
|
def integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the
    input, because we're computing the integral corresponding to the endpoints
    of a step function, not the integral of the interior/bin values.

    Args:
      w: Tensor, which will be integrated along the last axis. This is assumed
        to sum to 1 along the last axis, and this function will (silently)
        break if that is not the case.

    Returns:
      cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and
      cw0[..., -1] = 1. For floating-point w the result has w's dtype.
    """
    # The final cumulative value is not computed; it is pinned to exactly 1.
    cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
    pad_shape = cw.shape[:-1] + (1,)
    # Create the 0/1 pads in cw's own floating dtype so the output dtype does
    # not depend on the global default dtype. Non-floating input keeps the
    # default-dtype pads (and hence the promotion done by torch.cat).
    pad_dtype = cw.dtype if cw.is_floating_point() else None
    zeros = torch.zeros(pad_shape, dtype=pad_dtype, device=cw.device)
    ones = torch.ones(pad_shape, dtype=pad_dtype, device=cw.device)
    return torch.cat([zeros, cw, ones], dim=-1)
|
|
|
|
|
def integrate_weights_np(w):
    """NumPy variant: cumulative sum of w, assuming each vector sums to 1.

    The result is one element longer than the input along the last axis,
    since it holds the integral at the endpoints of a step function rather
    than at its interior/bin values.

    Args:
      w: array, integrated along the last axis. It is assumed to sum to 1
        along that axis; the function (silently) breaks otherwise.

    Returns:
      cw0: array, the integral of w, with cw0[..., 0] = 0 and
      cw0[..., -1] = 1.
    """
    # The last cumulative value is not computed; it is pinned to exactly 1.
    partial = np.cumsum(w[..., :-1], axis=-1)
    partial = np.minimum(1, partial)
    edge_shape = partial.shape[:-1] + (1,)
    leading = np.zeros(edge_shape)
    trailing = np.ones(edge_shape)
    return np.concatenate([leading, partial, trailing], axis=-1)
|
|
|
|
|
def invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    Args:
      u: tensor, the CDF values at which to evaluate the inverse.
      t: tensor, bin endpoint coordinates.
      w_logits: tensor, logits of the bin weights; a softmax over the last
        dimension turns them into weights.

    Returns:
      t_new: the result of interpolating t at u against the integrated
      weights.
    """
    # Normalise the logits into weights and accumulate them into a CDF.
    cdf = integrate_weights(torch.softmax(w_logits, dim=-1))
    # NOTE(review): `math` must be the project helper module that provides
    # `sorted_interp` (the stdlib `math` has no such function) - confirm it
    # is imported at the top of the file.
    return math.sorted_interp(u, cdf, t)
|
|
|
|
|
def invert_cdf_np(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    NumPy counterpart of `invert_cdf`.

    Args:
      u: array, the CDF values at which to evaluate the inverse.
      t: array, bin endpoint coordinates along the last axis.
      w_logits: array, logits of the bin weights; a softmax over the last
        axis turns them into weights.

    Returns:
      t_new: array, t linearly interpolated at u against the integrated
      weights.
    """
    # Softmax with the row maximum subtracted first: mathematically the same
    # as exp(x) / sum(exp(x)), but exp() cannot overflow for large logits.
    w_logits = np.asarray(w_logits)
    shifted = w_logits - np.max(w_logits, axis=-1, keepdims=True)
    w = np.exp(shifted)
    w = w / w.sum(axis=-1, keepdims=True)
    cw = integrate_weights_np(w)

    # np.interp only accepts 1-D sample points/values. Keep the direct call
    # for that case and map it over the leading axes for batched input.
    if np.ndim(cw) == 1 and np.ndim(t) == 1:
        return np.interp(u, cw, t)
    batched_interp = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
    return batched_interp(u, cw, t)
|
|
|
|
|
def sample(rand,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False):
    """Piecewise-Constant PDF sampling from a step function.

    Args:
      rand: random number generator (or None for `linspace` sampling). Only
        its truthiness is inspected here; the random draws themselves come
        from `torch.rand`.
      t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
      w_logits: [..., num_bins], logits corresponding to bin weights
      num_samples: int, the number of samples.
      single_jitter: bool, if True, jitter every sample along each ray by the
        same amount in the inverse CDF. Otherwise, jitter each sample
        independently.
      deterministic_center: bool, if False, when `rand` is None return samples
        that linspace the entire PDF. If True, skip the front and back of the
        linspace so that the centers of each PDF interval are returned.

    Returns:
      t_samples: [batch_size, num_samples].
    """
    eps = torch.finfo(t.dtype).eps
    device = t.device
    batch_shape = t.shape[:-1]

    if not rand:
        # Deterministic, evenly spaced CDF positions kept strictly below 1.
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
        else:
            u = torch.linspace(0, 1. - eps, num_samples, device=device)
        u = torch.broadcast_to(u, batch_shape + (num_samples,))
    else:
        # Stratified sampling: evenly spaced strata plus a random jitter that
        # is small enough to keep every sample strictly below 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # The expression above simplifies to (1 - eps) / num_samples - eps.
            # Use that closed form for a single sample, where the original
            # formula would divide by zero.
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        strata = torch.linspace(0, 1 - u_max, num_samples, device=device)
        jitter = torch.rand(batch_shape + (d,), device=device) * max_jitter
        u = strata + jitter

    return invert_cdf(u, t, w_logits)
|
|
|
|
|
def sample_np(rand,
              t,
              w_logits,
              num_samples,
              single_jitter=False,
              deterministic_center=False):
    """NumPy version of sample(): piecewise-constant PDF sampling.

    Args:
      rand: random number generator (or None for `linspace` sampling). Only
        its truthiness is inspected here; the random draws themselves come
        from `np.random.rand`.
      t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
      w_logits: [..., num_bins], logits corresponding to bin weights
      num_samples: int, the number of samples.
      single_jitter: bool, if True, jitter every sample along each ray by the
        same amount in the inverse CDF. Otherwise, jitter each sample
        independently.
      deterministic_center: bool, if False, when `rand` is None return samples
        that linspace the entire PDF. If True, skip the front and back of the
        linspace so that the centers of each PDF interval are returned.

    Returns:
      t_samples: the inverse CDF evaluated at the chosen positions.
    """
    eps = np.finfo(np.float32).eps
    batch_shape = t.shape[:-1]

    if not rand:
        # Deterministic, evenly spaced CDF positions kept strictly below 1.
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = np.linspace(pad, 1. - pad - eps, num_samples)
        else:
            u = np.linspace(0, 1. - eps, num_samples)
        u = np.broadcast_to(u, batch_shape + (num_samples,))
    else:
        # Stratified sampling: evenly spaced strata plus a random jitter that
        # is small enough to keep every sample strictly below 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # The expression above simplifies to (1 - eps) / num_samples - eps.
            # Use that closed form for a single sample, where the original
            # formula would divide by zero and give a non-finite jitter.
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        strata = np.linspace(0, 1 - u_max, num_samples)
        jitter = np.random.rand(*batch_shape, d) * max_jitter
        u = strata + jitter

    return invert_cdf_np(u, t, w_logits)
|
|
|
|
|
def sample_intervals(rand,
                     t,
                     w_logits,
                     num_samples,
                     single_jitter=False,
                     domain=(-torch.inf, torch.inf)):
    """Sample *intervals* (rather than points) from a step function.

    Args:
      rand: random number generator (or None for `linspace` sampling).
      t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
      w_logits: [..., num_bins], logits corresponding to bin weights
      num_samples: int, the number of intervals to sample.
      single_jitter: bool, if True, jitter every sample along each ray by the
        same amount in the inverse CDF. Otherwise, jitter each sample
        independently.
      domain: (minval, maxval), the range of valid values for `t`.

    Returns:
      t_samples: [..., num_samples + 1], the endpoints of the sampled
      intervals.

    Raises:
      ValueError: if `num_samples` is not greater than 1.
    """
    if num_samples <= 1:
        raise ValueError(f'num_samples must be > 1, is {num_samples}.')

    # Draw one point per interval; these act as the interval centers.
    centers = sample(
        rand,
        t,
        w_logits,
        num_samples,
        single_jitter,
        deterministic_center=True)

    # Interior endpoints sit halfway between adjacent centers.
    left, right = centers[..., :-1], centers[..., 1:]
    mid = (right + left) / 2

    # The outermost endpoints mirror the nearest interior endpoint about the
    # first/last center, and are then kept inside the domain.
    lower, upper = domain
    first = torch.clamp(2 * centers[..., :1] - mid[..., :1], min=lower)
    last = torch.clamp(2 * centers[..., -1:] - mid[..., -1:], max=upper)

    return torch.cat([first, mid, last], dim=-1)
|
|
|
|
|
def lossfun_distortion(t, w):
    """Compute iint w[i] w[j] |t[i] - t[j]| di dj.

    Args:
      t: tensor, interval endpoints along the last dimension.
      w: tensor, one weight per interval.

    Returns:
      Tensor with the last dimension reduced: the sum of an inter-interval
      term (weighted distances between interval midpoints) and an
      intra-interval term (squared weight times interval width, over 3).
    """
    upper, lower = t[..., 1:], t[..., :-1]
    midpoints = (upper + lower) / 2
    widths = upper - lower

    # Distortion between all pairs of intervals, measured at their midpoints.
    pair_dist = (midpoints[..., :, None] - midpoints[..., None, :]).abs()
    per_interval = (w[..., None, :] * pair_dist).sum(dim=-1)
    loss_inter = (w * per_interval).sum(dim=-1)

    # Distortion of each interval with itself.
    loss_intra = (w ** 2 * widths).sum(dim=-1) / 3

    return loss_inter + loss_intra
|
|
|
|
|
def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
    """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).

    Args:
      t0_lo: tensor, lower ends of the first intervals.
      t0_hi: tensor, upper ends of the first intervals.
      t1_lo: tensor, lower ends of the second intervals.
      t1_hi: tensor, upper ends of the second intervals.

    Returns:
      Tensor holding, element-wise, the distance between interval midpoints
      where the two intervals are disjoint and the closed-form overlap
      expression otherwise.
    """
    # Disjoint intervals: the mean distance is the distance between midpoints.
    center0 = (t0_lo + t0_hi) / 2
    center1 = (t1_lo + t1_hi) / 2
    d_disjoint = torch.abs(center1 - center0)

    # Overlapping intervals: closed form, split into named pieces. The order
    # of the additions is kept as-is so rounding is unchanged.
    hi_min = torch.minimum(t0_hi, t1_hi)
    lo_max = torch.maximum(t0_lo, t1_lo)
    cubic = 2 * (hi_min ** 3 - lo_max ** 3)
    cross = (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
             t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) +
             t1_hi * t0_lo * (t0_lo - t1_hi) +
             t1_lo * t0_hi * (t1_lo - t0_hi))
    scale = 6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)
    d_overlap = (cubic + 3 * cross) / scale

    # Both candidates are evaluated everywhere; pick per element.
    are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)
    return torch.where(are_disjoint, d_disjoint, d_overlap)
|
|
|
|
|
def weighted_percentile(t, w, ps):
    """Compute the weighted percentiles of a step function. w's must sum to 1.

    Args:
      t: tensor, interval endpoints along the last dimension.
      w: tensor, one weight per interval, summing to 1.
      ps: sequence of percentiles, expressed on a 0-100 scale.

    Returns:
      wprctile: tensor with the leading dimensions of the integrated weights
      and a last dimension of size len(ps).
    """
    cw = integrate_weights(w)
    # Percentiles are given on a 0-100 scale; the CDF lives on [0, 1].
    fractions = torch.tensor(ps, device=t.device) / 100

    # Collapse all leading dimensions into one batch axis for the lookup.
    cw_rows = cw.reshape([-1, cw.shape[-1]])
    t_rows = t.reshape([-1, t.shape[-1]])
    # NOTE(review): `math` must be the project helper module that provides
    # `sorted_interp` (the stdlib `math` has no such function).
    rows = math.sorted_interp(fractions, cw_rows, t_rows)
    return rows.reshape(cw.shape[:-1] + (len(ps),))
|
|
|
|
|
def resample(t, tp, vp, use_avg=False):
    """Resample a step function defined by (tp, vp) into intervals t.

    Args:
      t: tensor with shape (..., n+1), the endpoints to resample into.
      tp: tensor with shape (..., m+1), the endpoints of the step function
        being resampled.
      vp: tensor with shape (..., m), the values of the step function being
        resampled.
      use_avg: bool, if False, return the sum of the step function for each
        interval in `t`. If True, return the width-weighted average instead:
        the resampled sum of `vp * diff(tp)` divided by the resampled sum of
        `diff(tp)` (the denominator is clamped from below at the machine
        epsilon of t's dtype).

    Returns:
      v: tensor with shape (..., n), the values of the resampled step
      function.
    """
    eps = torch.finfo(t.dtype).eps

    if not use_avg:
        # Interpolate the zero-prefixed running sum of vp at the new
        # endpoints; differences of it give the sum inside each interval.
        running = torch.cumsum(vp, dim=-1)
        zero = torch.zeros(running.shape[:-1] + (1,), device=running.device)
        cumulative = torch.cat([zero, running], dim=-1)
        # NOTE(review): `math` must be the project helper module that
        # provides `sorted_interp` (the stdlib `math` has no such function).
        cumulative_at_t = math.sorted_interp(t, tp, cumulative)
        return torch.diff(cumulative_at_t, dim=-1)

    widths = torch.diff(tp, dim=-1)
    numer = resample(t, tp, vp * widths, use_avg=False)
    denom = resample(t, tp, widths, use_avg=False)
    return numer / denom.clamp_min(eps)
|
|
|
|
|
def resample_np(t, tp, vp, use_avg=False):
    """NumPy version of resample(): resample the step function (tp, vp) into t.

    Args:
      t: array with shape (..., n+1), the endpoints to resample into.
      tp: array with shape (..., m+1), the endpoints of the step function
        being resampled.
      vp: array with shape (..., m), the values of the step function being
        resampled.
      use_avg: bool, if False, return the sum of the step function for each
        interval in `t`. If True, return the width-weighted average instead:
        the resampled sum of `vp * diff(tp)` divided by the resampled sum of
        `diff(tp)`.

    Returns:
      v: array with shape (..., n), the values of the resampled step function.
    """
    if use_avg:
        # eps is only needed here. np.finfo rejects integer dtypes, so fall
        # back to float64 when t is not floating point instead of raising.
        t_dtype = np.asarray(t).dtype
        eps = np.finfo(t_dtype if t_dtype.kind in 'fc' else np.float64).eps
        wp = np.diff(tp, axis=-1)
        v_numer = resample_np(t, tp, vp * wp, use_avg=False)
        v_denom = resample_np(t, tp, wp, use_avg=False)
        return v_numer / np.maximum(eps, v_denom)

    # Interpolate the zero-prefixed running sum of vp at the new endpoints;
    # differences of it give the sum inside each new interval.
    acc = np.cumsum(vp, axis=-1)
    acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)
    # np.interp is 1-D only, so map it over the leading (batch) axes.
    batched_interp = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
    acc0_resampled = batched_interp(t, tp, acc0)
    return np.diff(acc0_resampled, axis=-1)
|
|
|
|
|
def blur_stepfun(x, y, r):
    """Turn the step function (x, y) into a piecewise-linear one of radius r.

    NOTE(review): this looks like a convolution of the step function with a
    box filter of half-width r - confirm against the caller.

    Args:
      x: tensor, the endpoints of the step function along the last dimension.
      y: tensor, the step values (one fewer than x along the last dimension).
      r: the blur radius; every endpoint is shifted by -r and by +r.

    Returns:
      (xr, yr): the sorted shifted endpoints, and the value at each of them.
      yr starts at 0 and is clamped from below at 0.
    """
    shifted = torch.cat([x - r, x + r], dim=-1)
    xr, order = torch.sort(shifted)

    # Jump of the step function at every endpoint (treating it as 0 outside),
    # scaled by the filter width 2r.
    zero = torch.zeros_like(y[..., :1])
    jumps = torch.diff(y, dim=-1, prepend=zero, append=zero) / (2 * r)

    # A jump starts contributing at x - r and stops at x + r, so it enters
    # the slope with a plus sign and leaves with a minus sign, in sorted order.
    slope_changes = torch.cat([jumps, -jumps], dim=-1)
    slope_changes = slope_changes.take_along_dim(order[..., :-1], dim=-1)
    slopes = torch.cumsum(slope_changes, dim=-1)

    # Integrate the piecewise-constant slope over each segment of xr.
    segment_widths = torch.diff(xr, dim=-1)
    yr = torch.cumsum(segment_widths * slopes, dim=-1).clamp_min(0)
    yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
    return xr, yr