File size: 6,482 Bytes
0a88b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Zexin He
# The modifications are subject to the same license as the original.


"""

The ray marcher takes the raw output of the implicit representation (colors, signed distances and normals) and uses the volume rendering equation to produce composited colors, depths and normals.

Based on the implementation in MipNeRF (this one doesn't do any cone tracing, though!)

"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnedVariance(nn.Module):
    """Learnable inverse standard deviation stored in a scaled log space.

    The raw parameter ``_inv_std`` is mapped to its effective value through
    ``exp(10 * _inv_std)``.
    """

    def __init__(self, init_val):
        """Create the raw parameter from ``init_val``.

        Args:
            init_val: Initial value of the raw (pre-exponential) parameter.
        """
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it under the
        # attribute name, so the state-dict key stays "_inv_std".
        self._inv_std = nn.Parameter(torch.tensor(init_val))

    @property
    def inv_std(self):
        """Return the effective inverse std, ``exp(10 * _inv_std)``."""
        return (self._inv_std * 10.0).exp()

    def forward(self, x):
        """Return a tensor shaped like ``x`` filled with the inverse std.

        The inverse std is clamped to ``[1e-6, 1e6]`` before being broadcast.
        """
        bounded = self.inv_std.clamp(1.0e-6, 1.0e6)
        return torch.ones_like(x) * bounded


class MipRayMarcher2(nn.Module):
    """Alpha-composite colors, depths and normals along rays from SDF samples.

    For every interval between two consecutive samples on a ray, an opacity
    is derived from the signed distance at the interval midpoint (see
    ``get_alpha``). The opacities are turned into compositing weights by
    accumulating transmittance along the sample axis.
    """

    def __init__(self, activation_factory):
        """Set up the ray marcher.

        Args:
            activation_factory: Stored on the instance as
                ``self.activation_factory``; no method of this class reads it.
        """
        super().__init__()
        self.activation_factory = activation_factory
        # Learned sharpness of the sigmoid used to turn SDF values into opacity.
        self.variance = LearnedVariance(0.3)
        # Blend factor between the two cosine terms in ``get_alpha``;
        # 1.0 selects the plain ``relu(-cos)`` term only.
        self.cos_anneal_ratio = 1.0

    def get_alpha(self, sdf, normal, dirs, dists):
        """Convert signed distances into per-interval opacities.

        Args:
            sdf: Signed distance at the interval midpoints, ``[N, 1]``.
            normal: Normal at the interval midpoints, ``[N, 3]``.
            dirs: Ray direction for each interval, ``[N, 3]``.
            dists: Length of each interval, ``[N, 1]``.

        Returns:
            Opacity per interval, clipped to ``[0, 1]``, shape ``[N, 1]``.
        """
        inv_std = self.variance(sdf)

        true_cos = (dirs * normal).sum(-1, keepdim=True)
        # "cos_anneal_ratio" is meant to grow from 0 to 1 over the first
        # training iterations. The blend below keeps the cos value "not dead"
        # early in training, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at the two ends of each interval.
        estimated_next_sdf = sdf + iter_cos * dists * 0.5
        estimated_prev_sdf = sdf - iter_cos * dists * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_std)

        p = prev_cdf - next_cdf
        c = prev_cdf

        # The small epsilons keep the ratio finite when both CDFs vanish.
        alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)
        return alpha

    @staticmethod
    def _weighted_average(weights, values_mid, weight_total, values):
        """Weight-normalised composite of ``values_mid``, clamped to the range of ``values``."""
        composite = torch.sum(weights * values_mid, -2) / weight_total
        # 0 / 0 yields NaN; map it to +inf so the clamp below pins it to the
        # maximum of ``values``.
        composite = torch.nan_to_num(composite, float('inf'))
        return torch.clamp(composite, torch.min(values), torch.max(values))

    def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        """Composite per-sample quantities along each ray.

        All per-sample tensors are indexed along dimension 2 as the sample
        axis and averaged between consecutive samples, so ``N_sample`` inputs
        produce ``N_sample - 1`` intervals.

        Args:
            colors: Per-sample colors, expected ``[B, N_ray, N_sample, C]``.
            sdfs: Per-sample signed distances, ``[B, N_ray, N_sample, 1]``.
            depths: Per-sample depths, expected ``[B, N_ray, N_sample, 1]``.
            normals: Per-sample normals used for the opacity computation,
                ``[B, N_ray, N_sample, 3]``.
            ray_directions: Direction of each ray, ``[B, N_ray, 3]``.
            rendering_options: Mapping; only the ``'white_back'`` key is read.
            bgcolor: Optional background image with channels in dimension 1
                (4-D); only used when ``'white_back'`` is truthy. It is
                flattened to ``[B, -1, C]`` and is expected to provide one
                pixel per ray.
            real_normals: Optional per-sample normals to composite, same
                layout as ``normals``. When ``None``, no normal map is
                composited.

        Returns:
            Tuple ``(composite_rgb, composite_depth, weights,
            composite_normal)``. ``composite_normal`` is ``None`` when
            ``real_normals`` is ``None``.
        """
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        sdfs_mid = (sdfs[:, :, :-1] + sdfs[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
        normals_mid = (normals[:, :, :-1] + normals[:, :, 1:]) / 2

        # Repeat each ray direction for every interval on that ray.
        dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1)
        B, N_ray, N_sample, _ = sdfs_mid.shape
        alpha = self.get_alpha(sdfs_mid.reshape(-1, 1), normals_mid.reshape(-1, 3), dirs.reshape(-1, 3), deltas.reshape(-1, 1))
        alpha = alpha.reshape(B, N_ray, N_sample, -1)

        # Transmittance before each interval: cumulative product of
        # (1 - alpha), starting from 1 for the first interval.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)

        # Clip the composite to the min/max range of the input depths.
        composite_depth = self._weighted_average(weights, depths_mid, weight_total, depths)

        # Normal compositing (added by zhaohx). ``real_normals`` defaults to
        # None, so it must not be indexed unconditionally.
        if real_normals is None:
            composite_normal = None
        else:
            real_normals_mid = (real_normals[:, :, :-1] + real_normals[:, :, 1:]) / 2
            composite_normal = self._weighted_average(weights, real_normals_mid, weight_total, real_normals)

        if rendering_options.get('white_back', False):
            if bgcolor is None:
                # Fill the remaining transmittance with white.
                composite_rgb = composite_rgb + 1 - weight_total
            else:
                bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1])
                composite_rgb = composite_rgb + (1 - weight_total) * bgcolor

        # Rendered value scale is 0-1; the original MipNeRF rescaling to
        # (-1, 1) is intentionally not applied.
        return composite_rgb, composite_depth, weights, composite_normal

    def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        """Run ``run_forward`` with the same arguments and return its result."""
        composite_rgb, composite_depth, weights, composite_normal = self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals)

        return composite_rgb, composite_depth, weights, composite_normal