import numpy as np
import random
import torch
import math


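def rotate_axis(x, add_angle=0, axis=1):  # TODO: replace with a rotation matrix (but this is more fun)
    """Rotate points `x` (shape [..., 3]) by `add_angle` radians about coordinate axis `axis`."""
    axes = list(range(3))
    axes.remove(axis)
    ax1, ax2 = axes
    # Polar angle of each point in the (ax1, ax2) plane
    angle = torch.atan2(x[..., ax1], x[..., ax2])
    if isinstance(add_angle, torch.Tensor):
        # Append trailing singleton dims so a batched angle broadcasts against `angle`
        while add_angle.ndim < angle.ndim:
            add_angle = add_angle.unsqueeze(-1)
    angle = angle + add_angle
    # Radius within the rotation plane; the full 3D norm would wrongly include x[..., axis]
    dist = torch.hypot(x[..., ax1], x[..., ax2])
    # Reassemble the three components in axis order 0, 1, 2
    _, t = zip(*sorted([
        (axis, x[..., axis]),
        (ax1, torch.sin(angle) * dist),
        (ax2, torch.cos(angle) * dist),
    ], key=lambda pair: pair[0]))
    return torch.stack(t, dim=-1)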
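

# Sketch (not part of the original module): the rotation-matrix form mentioned in the TODO.
# It keeps rotate_axis's angle convention (angle = atan2(x[ax1], x[ax2]) plus the added
# angle). That gives x1' = x1*cos + x2*sin and x2' = x2*cos - x1*sin, while x[axis] is
# left unchanged. `rotation_matrix_about_axis` is a hypothetical name, and NumPy is used
# here only for illustration.
import numpy as np


def rotation_matrix_about_axis(angle, axis=1):
    """Return a 3x3 matrix R so that `x @ R.T` rotates row vectors x about `axis`."""
    ax1, ax2 = [a for a in range(3) if a != axis]
    c, s = np.cos(angle), np.sin(angle)
    R = np.eye(3)
    R[ax1, ax1], R[ax1, ax2] = c, s
    R[ax2, ax1], R[ax2, ax2] = -s, c
    return R

# A single matrix product also batches naturally: points of shape [N, 3] rotate as x @ R.T.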


noise_level = 0.5


# stolen from https://gist.github.com/ac1b097753f217c5c11bc2ff396e0a57
# ported from https://github.com/pvigier/perlin-numpy/blob/master/perlin2d.py
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

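    # Fractional position of each pixel inside its lattice cell (explicit "ij" indexing
    # matches the old default and avoids the PyTorch deprecation warning)
    grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]),
                                      indexing="ij"), dim=-1) % 1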
    angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

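    def tile_grads(slice1, slice2):
        # Repeat each corner gradient across the d[0] x d[1] pixels of its cell
        return (gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]]
                .repeat_interleave(d[0], 0)
                .repeat_interleave(d[1], 1))

    def dot(grad, shift):
        # Dot product of each pixel's offset from a cell corner with that corner's gradient
        return (torch.stack((grid[:shape[0], :shape[1], 0] + shift[0],
                             grid[:shape[0], :shape[1], 1] + shift[1]), dim=-1)
                * grad[:shape[0], :shape[1]]).sum(dim=-1)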

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[:shape[0], :shape[1]])
    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
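

# Sketch (not part of the original module): the same gradient-noise scheme written with
# NumPy, so each step is explicit. `perlin_2d_numpy` is a hypothetical name. Like
# rand_perlin_2d, it assumes each entry of `shape` is divisible by the matching entry of `res`.
import numpy as np


def perlin_2d_numpy(shape, res, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    d = (shape[0] // res[0], shape[1] // res[1])  # pixels per lattice cell
    rows, cols = np.arange(shape[0]), np.arange(shape[1])
    # Lattice cell index and fractional position in [0, 1) of every pixel
    iy, ix = np.meshgrid(rows // d[0], cols // d[1], indexing="ij")
    fy, fx = np.meshgrid((rows % d[0]) / d[0], (cols % d[1]) / d[1], indexing="ij")
    # One random unit gradient per lattice corner
    angles = 2 * np.pi * rng.random((res[0] + 1, res[1] + 1))
    grads = np.stack((np.cos(angles), np.sin(angles)), axis=-1)

    def corner(dy, dx):
        # Dot product of the corner gradient with the pixel's offset from that corner
        g = grads[iy + dy, ix + dx]
        return g[..., 0] * (fy - dy) + g[..., 1] * (fx - dx)

    # Quintic fade 6t^5 - 15t^4 + 10t^3, the same curve as rand_perlin_2d's default
    ty = 6 * fy ** 5 - 15 * fy ** 4 + 10 * fy ** 3
    tx = 6 * fx ** 5 - 15 * fx ** 4 + 10 * fx ** 3
    # Interpolate along the first axis, then along the second
    n0 = corner(0, 0) + ty * (corner(1, 0) - corner(0, 0))
    n1 = corner(0, 1) + ty * (corner(1, 1) - corner(0, 1))
    return np.sqrt(2) * (n0 + tx * (n1 - n0))

# Because every offset is zero at a lattice corner, the noise is exactly 0 on lattice points.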


def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
    noise = torch.zeros(shape)
    frequency = 1
    amplitude = 1
    for _ in range(octaves):
        noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
        frequency *= 2
        amplitude *= persistence
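    noise *= random.random() - noise_level  # random contrast in [-0.5, 0.5) for noise_level = 0.5; negative values flip the pattern
    noise += random.random() - noise_level  # random brightness offset in [-0.5, 0.5) for noise_level = 0.5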
    return noise
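

# Sketch (not part of the original module): the octave loop above doubles the lattice
# resolution and multiplies the amplitude by `persistence` on each pass. rand_perlin_2d
# appears to need `shape` divisible by `res`, because it tiles each gradient shape // res
# times. If so, every octave's resolution frequency * res must divide `shape` too.
# `octave_schedule` is a hypothetical helper for checking this ahead of time.
def octave_schedule(shape, res, octaves=1, persistence=0.5):
    """Return [(frequency, amplitude), ...]; raise if some octave's lattice does not tile `shape`."""
    schedule = []
    frequency, amplitude = 1, 1.0
    for _ in range(octaves):
        r0, r1 = frequency * res[0], frequency * res[1]
        if shape[0] % r0 or shape[1] % r1:
            raise ValueError(f"shape {shape} is not divisible by octave resolution {(r0, r1)}")
        schedule.append((frequency, amplitude))
        frequency *= 2
        amplitude *= persistence
    return schedule

# With 4 octaves and persistence 0.5, the amplitudes sum to 1 + 0.5 + 0.25 + 0.125 = 1.875.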


def load_clip(model_name="ViT-B/16", device="cuda:0" if torch.cuda.is_available() else "cpu"):
    import clip
    model, preprocess = clip.load(model_name, device=device, jit=False)
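    if len(preprocess.transforms) > 4:
        # Keep only the final transform (Normalize in the openai/clip pipeline), presumably so
        # that image tensors that are already resized can be passed in directly
        preprocess.transforms = preprocess.transforms[-1:]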
    return model, preprocess
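

# Sketch (not part of the original module): after the truncation in load_clip, the one
# remaining step is, in the openai/clip package, a torchvision Normalize using CLIP's
# per-channel mean and std. This NumPy stand-in shows that arithmetic on a channel-first
# image in [0, 1]. `clip_normalize` is a hypothetical name, and the input is assumed to be
# already resized and cropped to the model resolution (224 x 224 for ViT-B/16).
import numpy as np

CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711])


def clip_normalize(image_chw):
    # Per channel: (x - mean) / std, broadcast over height and width
    return (image_chw - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]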


# http://blog.andreaskahler.com/2009/06/creating-icosphere-mesh-in-code.html
def ico():
    phi = (1 + 5 ** 0.5) / 2
    return (
        np.array([
            [-1, phi, 0],
            [1, phi, 0],
            [-1, -phi, 0],
            [1, -phi, 0],
            [0, -1, phi],
            [0, 1, phi],
            [0, -1, -phi],
            [0, 1, -phi],
            [phi, 0, -1],
            [phi, 0, 1],
            [-phi, 0, -1],
            [-phi, 0, 1]
        ]) / phi,
        [
            [0, 11, 5],
            [0, 5, 1],
            [0, 1, 7],
            [0, 7, 10],
            [0, 10, 11],
            [1, 5, 9],
            [5, 11, 4],
            [11, 10, 2],
            [10, 7, 6],
            [7, 1, 8],
            [3, 9, 4],
            [3, 4, 2],
            [3, 2, 6],
            [3, 6, 8],
            [3, 8, 9],
            [4, 9, 5],
            [2, 4, 11],
            [6, 2, 10],
            [8, 6, 7],
            [9, 8, 1]
        ]
    )
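

# Sketch (not part of the original module): a cross-check of the hard-coded table in ico().
# It rebuilds the icosahedron from the cyclic permutations of (0, ±1, ±phi) and recovers the
# faces as triples of mutually adjacent vertices (edge length 2 before dividing by phi).
# The result has the same 12 vertices and 20 faces, but in a different order and without
# consistent winding. `icosahedron_check` is a hypothetical name.
import itertools

import numpy as np


def icosahedron_check():
    phi = (1 + 5 ** 0.5) / 2
    verts = []
    for a in (-1.0, 1.0):
        for b in (-phi, phi):
            verts += [(a, b, 0.0), (0.0, a, b), (b, 0.0, a)]
    verts = np.array(verts)
    # Two vertices share an edge iff they are exactly one edge length (2) apart
    edges = {(i, j) for i, j in itertools.combinations(range(len(verts)), 2)
             if abs(np.linalg.norm(verts[i] - verts[j]) - 2.0) < 1e-9}
    # In an icosahedron, every triangle of mutually adjacent vertices is a face
    faces = [(i, j, k) for i, j, k in itertools.combinations(range(len(verts)), 3)
             if {(i, j), (i, k), (j, k)} <= edges]
    return verts / phi, faces, edges

# After the division by phi, every vertex lies sqrt(1 + phi^2) / phi ≈ 1.176 from the
# origin. A `radius` of r in ico_at therefore gives a circumradius of about 1.18 r.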


def ico_at(xyz=np.array([0, 0, 0]), radius=1.0, i=0):
    vert, idx = ico()
    return vert * radius + xyz, [[y + i for y in x] for x in idx]
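

# Sketch (not part of the original module): ico_at's `i` argument offsets face indices so
# that many small meshes can share one vertex list. The same idea in general form is shown
# below. `merge_meshes` is a hypothetical name.
def merge_meshes(meshes):
    """Concatenate (vertices, faces) pairs, shifting face indices by the running vertex count."""
    all_verts, all_faces = [], []
    for verts, faces in meshes:
        offset = len(all_verts)
        all_verts.extend(verts)
        all_faces.extend([[v + offset for v in face] for face in faces])
    return all_verts, all_faces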


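def save_ply(points, out_path):
    """Write points (positions, colors, radii, opacities) as small icospheres to an ASCII PLY file."""
    with torch.inference_mode():
        vert_pos, vert_col, vert_rad, vert_opa = (x.detach().cpu().numpy() for x in points)

    verts = []
    faces = []
    for xyz, radius, (r, g, b), a in zip(vert_pos, vert_rad, vert_col, vert_opa):
        # One icosahedron per point; face indices are offset by the vertices emitted so far
        v, i = ico_at(xyz, radius, len(verts))
        # PLY uchar channels must lie in 0..255, so clamp the [0, 1] color and opacity first
        rgba = [int(np.clip(c, 0, 1) * 255) for c in (r, g, b, a)]
        for x, y, z in v:
            verts.append((x, y, z, *rgba))
        faces += i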

    with open(out_path, "w") as out_file:
        out_file.write("ply\n")
        out_file.write("format ascii 1.0\n")
        out_file.write(f"element vertex {len(verts)}\n")
        out_file.write("property float x\n")
        out_file.write("property float y\n")
        out_file.write("property float z\n")
        out_file.write("property uchar red\n")
        out_file.write("property uchar green\n")
        out_file.write("property uchar blue\n")
        out_file.write("property uchar alpha\n")
        out_file.write(f"element face {len(faces)}\n")
        out_file.write("property list uchar int vertex_index\n")
        out_file.write("end_header\n")
        for v in verts:
            out_file.write(" ".join(map(str, v)) + "\n")
        for f in faces:
            out_file.write(" ".join(map(str, [len(f)] + f)) + "\n")
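

# Sketch (not part of the original module): a minimal reader for the ASCII PLY header that
# save_ply writes. It returns the element counts and checks that the body has that many
# non-empty lines, which assumes only `vertex` and `face` elements as above.
# `ply_ascii_counts` is a hypothetical name.
def ply_ascii_counts(text):
    lines = text.splitlines()
    if not lines or lines[0].strip() != "ply":
        raise ValueError("not a PLY file")
    counts, end = {}, None
    for n, line in enumerate(lines[1:], start=1):
        parts = line.split()
        if parts and parts[0] == "element":
            counts[parts[1]] = int(parts[2])
        elif parts and parts[0] == "end_header":
            end = n
            break
    if end is None:
        raise ValueError("missing end_header")
    body = [line for line in lines[end + 1:] if line.strip()]
    if len(body) != sum(counts.values()):
        raise ValueError(f"expected {sum(counts.values())} body lines, found {len(body)}")
    return counts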