Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
import torch | |
from torch.nn import functional as F | |
import numpy as np | |
try:
    import pydensecrf.densecrf as dcrf
    from pydensecrf.utils import (
        unary_from_softmax,
        unary_from_labels,
        create_pairwise_bilateral,
        create_pairwise_gaussian,
    )
except Exception:
    # pydensecrf is an optional dependency: if it is missing or fails to load,
    # mark it unavailable (dense_crf_post_process checks `dcrf is None`)
    # instead of failing at import time. `except Exception` rather than a bare
    # `except:` so KeyboardInterrupt / SystemExit are not swallowed.
    dcrf = None
def dense_crf_post_process(
    logits,
    image,
    n_labels=None,
    max_iters=5,
    pos_xy_std=(3, 3),
    pos_w=3,
    bi_xy_std=(80, 80),
    bi_rgb_std=(13, 13, 13),
    bi_w=10,
    gt_prob=0.7,
):
    """Refine a per-pixel class prediction with a fully-connected (dense) CRF.

    Args:
        logits: the unary input, in one of three forms:
            * ``torch.Tensor`` of shape [C, H, W] with raw class scores; a
              softmax over dim 0 is applied before building the unary term;
            * ``np.ndarray`` of shape [C, H, W], used as class probabilities
              as-is (no softmax is applied);
            * a label map of class ids (array or tensor whose ndim is not 3);
              ``n_labels`` is then required.
        image: [H, W, 3] image (numpy array or tensor) with the same spatial
            size as ``logits``. It feeds pydensecrf's bilateral term, which
            (per the pydensecrf docs) expects uint8 RGB data.
        n_labels: number of classes. Only used for label-map input; for
            [C, H, W] input it is taken from ``logits.shape[0]``.
        max_iters: number of CRF inference iterations.
        pos_xy_std: spatial std-dev of the location-only Gaussian term.
        pos_w: weight (``compat``) of the location-only Gaussian term.
        bi_xy_std: spatial std-dev of the bilateral term.
        bi_rgb_std: colour std-dev of the bilateral term.
        bi_w: weight (``compat``) of the bilateral term.
        gt_prob: certainty assigned to the given label for label-map input;
            forwarded to ``unary_from_labels``, which takes it as a required
            argument (the call previously omitted it and raised TypeError).

    Returns:
        ``torch.Tensor`` of shape [n_labels, H, W] holding the output of
        ``DenseCRF2D.inference`` reshaped to image layout.

    Raises:
        FileNotFoundError: if pydensecrf could not be imported.
        ValueError: if ``image`` is not an [H, W, 3] array.
    """
    if dcrf is None:
        raise FileNotFoundError(
            "pydensecrf is required to perform dense crf inference."
        )

    # --- unary term -------------------------------------------------------
    if isinstance(logits, torch.Tensor):
        logits = logits.detach()
        if logits.ndim == 3:
            # Tensors carry raw scores: normalise over the label axis. Upcast
            # first so half/bfloat16 inputs can be converted to numpy below.
            logits = F.softmax(logits.float(), dim=0)
        logits = logits.cpu().numpy()

    if logits.ndim == 3:
        U = unary_from_softmax(logits)
        n_labels = logits.shape[0]
    else:
        assert n_labels is not None, "n_labels is required for label-map input"
        U = unary_from_labels(logits, n_labels, gt_prob, zero_unsure=False)

    # --- image ------------------------------------------------------------
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    # pydensecrf reads the image through a C-contiguous buffer, so views such
    # as a transposed CHW -> HWC array must be materialised first. This is a
    # no-op for arrays that are already contiguous.
    image = np.ascontiguousarray(image)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"image must be an [H, W, 3] array, got shape {tuple(image.shape)}"
        )
    height, width = image.shape[:2]

    # --- CRF ----------------------------------------------------------------
    # DenseCRF2D is constructed as (width, height, n_labels).
    crf = dcrf.DenseCRF2D(width, height, n_labels)
    crf.setUnaryEnergy(U)
    # Colour-independent smoothness term: features are pixel positions only.
    crf.addPairwiseGaussian(
        sxy=pos_xy_std,
        compat=pos_w,
        kernel=dcrf.DIAG_KERNEL,
        normalization=dcrf.NORMALIZE_SYMMETRIC,
    )
    # Colour-dependent appearance term: features are (x, y, r, g, b).
    crf.addPairwiseBilateral(
        sxy=bi_xy_std,
        srgb=bi_rgb_std,
        rgbim=image,
        compat=bi_w,
        kernel=dcrf.DIAG_KERNEL,
        normalization=dcrf.NORMALIZE_SYMMETRIC,
    )
    # Run `max_iters` inference iterations; the result is flat per label.
    result = crf.inference(max_iters)
    result = np.asarray(result).reshape((n_labels, height, width))
    return torch.from_numpy(result)