Create model.
tea_model.py +63 -0
tea_model.py
ADDED
@@ -0,0 +1,63 @@
import torch
from torch import nn

# A VAE decoder, but with half-size output:
# the last upsample of the full decoder is omitted.

###
# Code adapted from madebyollin/taesd
# Residual block: a three-conv "long" path plus a skip, fused by a ReLU.
class Recon(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.long = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch_out, ch_out, 3, padding=1)
        )
        if ch_in != ch_out:
            # Project the skip with a 1x1 conv when the channel counts differ.
            self.short = nn.Conv2d(ch_in, ch_out, 1, bias=False)
        else:
            # Channels already match, so the skip is a plain identity.
            self.short = nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.long(x) + self.short(x))

class TeaEncoder(nn.Module):
    def __init__(self, ch_in):
        super().__init__()
        self.block_in = nn.Sequential(
            nn.Conv2d(ch_in, 64, 3, padding=1),
            nn.ReLU()
        )
        self.middle = nn.Sequential(
            *[Recon(64, 64) for _ in range(3)],
            # Upsampling: the opposite of a stride-2 downsampling conv.
            nn.Upsample(scale_factor=2),
            # bias=False gives a slightly simpler model with fewer parameters;
            # the input is already well represented by the preceding feature
            # maps, so a per-channel bias here may not add significant value.
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            # Second stage: the final upsample, to 1/2 of the image size.
            *[Recon(64, 64) for _ in range(3)],
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
        )
        self.block_out = nn.Sequential(
            Recon(64, 64),
            # Convert to RGB, regardless of the latent channel count.
            nn.Conv2d(64, 3, 3, padding=1),
        )

    def forward(self, x):
        # Smoothly clamp the latent values to the range (-3, 3).
        clamped = torch.tanh(x / 3) * 3
        cooked = self.middle(self.block_in(clamped))

        return self.block_out(cooked)
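
A quick way to sanity-check the added file is to push a random latent through the model; a minimal sketch, assuming 4 latent channels and a 32x32 latent (as for a 256x256 image with an SD-style VAE; neither value is fixed by this commit):

    import torch
    from tea_model import TeaEncoder

    # Hypothetical smoke test; ch_in=4 and the 32x32 latent size are assumptions.
    model = TeaEncoder(ch_in=4)
    latent = torch.randn(1, 4, 32, 32)
    with torch.no_grad():
        rgb = model(latent)
    print(rgb.shape)  # torch.Size([1, 3, 128, 128]): two 2x upsamples, so half of 256x256

The two nn.Upsample stages each double the spatial size, so the output lands at half the original image resolution, matching the "half-size output" note at the top of the file.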