File size: 2,236 Bytes
98e2c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch    
import torch.nn as nn

from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists

def make_criteria(conf):
    """Build the loss module selected by ``conf['type']``.

    Args:
        conf: Mapping describing the criterion. ``'type'`` is always read.
            For ``'PatchCoherentLoss'`` the keys ``'patch_size'``,
            ``'stride'``, ``'loop'`` and ``'coherent_alpha'`` are read as
            well and forwarded to the constructor.

    Returns:
        A ``PatchCoherentLoss`` instance when ``conf['type']`` is
        ``'PatchCoherentLoss'``.

    Raises:
        NotImplementedError: If ``conf['type']`` is ``'SWDLoss'``.
        ValueError: If ``conf['type']`` is any other value.
    """
    criteria_type = conf['type']
    if criteria_type == 'PatchCoherentLoss':
        return PatchCoherentLoss(
            conf['patch_size'],
            stride=conf['stride'],
            loop=conf['loop'],
            coherent_alpha=conf['coherent_alpha'],
        )
    if criteria_type == 'SWDLoss':
        raise NotImplementedError('SWDLoss is not implemented')
    # Report the value that was actually dispatched on; looking up a
    # different key here would mask the real problem behind a KeyError.
    raise ValueError('Invalid criteria: {}'.format(criteria_type))

class PatchCoherentLoss(torch.nn.Module):
    """Nearest-neighbour patch loss between an input and a set of targets.

    Patches are extracted from the input and from every target, each input
    patch is matched against the pooled target patches through
    ``get_NNs_Dists``, and the mean of the returned distances is the loss.

    Args:
        patch_size: Patch size forwarded to ``extract_patches`` and
            ``combine_patches``.
        stride: Patch stride; only ``1`` is accepted.
        loop: Forwarded as ``loop=`` when extracting patches from the input
            and when combining patches (targets always use ``loop=False``).
        coherent_alpha: Forwarded unchanged to ``get_NNs_Dists``.
        cache: When true, the target patches computed by the first
            ``forward`` call are stored and reused by later calls until
            ``clean_cache`` is invoked.
    """

    def __init__(self, patch_size=7, stride=1, loop=False, coherent_alpha=None, cache=False):
        super(PatchCoherentLoss, self).__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.loop = loop
        self.coherent_alpha = coherent_alpha
        assert self.stride == 1, "Only support stride of 1"
        # No parity check on patch_size: an odd-size assertion used to sit
        # here but is disabled, so even sizes are accepted.
        self.cache = cache
        # Always defined so the attribute exists whether or not caching is
        # enabled; it is only ever populated when ``cache`` is true.
        self.cached_data = None

    def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
        """For each patch in input X find its NN in the targets Ys and average their distances.

        Args:
            X: Input tensor; its first dimension must be 1.
            Ys: Iterable of target tensors, each passed to
                ``extract_patches``.
            dist_wrapper: Optional callable invoked as
                ``dist_wrapper(efficient_cdist, a, b)`` in place of calling
                ``efficient_cdist(a, b)`` directly.
            ext: Unused; kept for interface compatibility.
            return_blended_results: When true, also return the output of
                ``combine_patches`` applied to the matched target patches.

        Returns:
            ``dist.mean()``, or the tuple ``(combined, dist.mean())`` when
            ``return_blended_results`` is true.

        Note:
            With ``cache=True`` the stored target patches are reused even if
            a different ``Ys`` is passed; call ``clean_cache`` first to
            force re-extraction.
        """
        assert X.shape[0] == 1, "Only support batch size of 1"

        def dist_fn(a, b):
            # Route through the caller-supplied wrapper when one is given.
            if dist_wrapper is not None:
                return dist_wrapper(efficient_cdist, a, b)
            return efficient_cdist(a, b)

        x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)

        if self.cache and self.cached_data is not None:
            y_patches = self.cached_data
        else:
            # Pool the patches of all targets along dim 1.
            # NOTE(review): assumes extract_patches returns (1, n_patches, dim) — confirm.
            y_patches = torch.cat(
                [extract_patches(y, self.patch_size, self.stride, loop=False) for y in Ys],
                dim=1,
            )
            if self.cache:
                # Only hold on to the (potentially large) tensor when the
                # caller asked for caching.
                self.cached_data = y_patches

        nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.coherent_alpha)
        loss = dist.mean()

        if not return_blended_results:
            return loss

        # Gather the matched target patch for every input patch and hand
        # them to combine_patches together with the input's shape.
        matched = y_patches[:, nnf, :]
        blended = combine_patches(X.shape, matched, self.patch_size, self.stride, loop=self.loop)
        return blended, loss

    def clean_cache(self):
        """Drop any stored target patches so the next ``forward`` recomputes them."""
        self.cached_data = None