File size: 2,714 Bytes
0a88b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import nvdiffrast.torch as dr
import torch

from ...utils.typing import *


class NVDiffRasterizerContext:
    """Thin wrapper around a single nvdiffrast rasterizer context.

    Holds one nvdiffrast context (OpenGL- or CUDA-backed) and exposes the
    rasterize / interpolate / antialias ops. Inputs are cast at the call
    boundary because nvdiffrast expects float32 positions/attributes and
    int32 triangle indices.

    Attributes:
        device: The device passed at construction (also forwarded to the
            nvdiffrast context constructor).
        ctx: The underlying ``dr.RasterizeGLContext`` or
            ``dr.RasterizeCudaContext``.
    """

    def __init__(self, context_type: str, device: torch.device) -> None:
        """Create the wrapper and its nvdiffrast context.

        Args:
            context_type: ``"gl"`` for ``dr.RasterizeGLContext`` or ``"cuda"``
                for ``dr.RasterizeCudaContext``.
            device: Device handed to the context constructor.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Construct the nvdiffrast context selected by ``context_type``.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device passed to the context constructor.

        Returns:
            A new ``dr.RasterizeGLContext`` or ``dr.RasterizeCudaContext``.

        Raises:
            ValueError: For any other ``context_type``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform one vertex set by a batch of 4x4 matrices.

        Vertices are lifted to homogeneous coordinates (w = 1) and multiplied
        as row vectors by each matrix's transpose, so every output row equals
        ``mvp_mtx[b] @ [x, y, z, 1]``.

        Args:
            verts: Vertex positions, shape (Nv, 3). A batched (B, Nv, 3)
                input also works through ``torch.matmul`` broadcasting.
            mvp_mtx: Batch of 4x4 transforms, shape (B, 4, 4).

        Returns:
            Homogeneous transformed positions, shape (B, Nv, 4).
        """
        # Allocate the w=1 column with verts' own dtype/device. This avoids
        # building a CPU tensor and copying it to the GPU on every call.
        ones = torch.ones_like(verts[..., :1])
        verts_homo = torch.cat([verts, ones], dim=-1)
        return torch.matmul(verts_homo, mvp_mtx.transpose(-1, -2))

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "B H W 4"], Float[Tensor, "B H W 4"]]:
        """Rasterize a batch of vertex sets that share one triangle list.

        A 3D ``pos`` combined with a 2D ``tri`` puts nvdiffrast in instanced
        mode: every batch entry is drawn with the same topology.

        Args:
            pos: Homogeneous vertex positions, shape (B, Nv, 4).
            tri: Triangle vertex indices, shape (Nf, 3).
            resolution: Output size as ``(height, width)``, or a single int
                for a square ``(resolution, resolution)`` image.

        Returns:
            ``(rast, rast_db)`` exactly as returned by ``dr.rasterize``.
            ``grad_db=True`` is passed so nvdiffrast may propagate gradients
            of the image-space barycentric derivatives back into ``pos``.
        """
        # dr.rasterize calls tuple(resolution), which fails on a bare int, so
        # expand the int form advertised by the annotation into (H, W).
        if isinstance(resolution, int):
            resolution = (resolution, resolution)
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "H W 4"], Float[Tensor, "H W 4"]]:
        """Rasterize a single mesh under a single viewpoint.

        Adds a batch dimension of size 1, calls :meth:`rasterize`, and strips
        the batch dimension from both outputs.

        Returns:
            ``(rast, rast_db)`` for the single view.
        """
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Antialias silhouette edges of ``color`` via ``dr.antialias``.

        Args:
            color: Shaded image, shape (B, H, W, C).
            rast: Rasterizer output for the same views, shape (B, H, W, 4).
            pos: Vertex positions used to produce ``rast``.
            tri: Triangle indices used to produce ``rast``.

        Returns:
            The antialiased image, same shape as ``color``.
        """
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over rasterized pixels.

        Args:
            attr: Per-vertex attributes, shape (B, Nv, C).
            rast: Rasterizer output, shape (B, H, W, 4).
            tri: Triangle indices, shape (Nf, 3).
            rast_db: Optional ``rast_db`` from rasterization. It is forwarded
                unchanged and needed only when ``diff_attrs`` is given.
            diff_attrs: Forwarded to ``dr.interpolate``. Per the nvdiffrast
                docs this is ``None``, ``"all"``, or a list of attribute
                indices whose image-space derivatives should be computed.

        Returns:
            ``(out, out_da)`` as returned by ``dr.interpolate``. ``out`` holds
            the interpolated attributes (B, H, W, C). ``out_da`` holds the
            attribute derivatives when ``diff_attrs`` is set.
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate one shared attribute set across every rasterized view.

        Adds a leading batch dimension of size 1 to ``attr`` and delegates to
        :meth:`interpolate`. nvdiffrast documents broadcasting of ``attr``
        along the minibatch axis, so the same attributes apply to all B views.

        Returns:
            ``(out, out_da)`` as returned by :meth:`interpolate`.
        """
        return self.interpolate(
            attr[None, ...], rast, tri, rast_db=rast_db, diff_attrs=diff_attrs
        )