sgerard commited on
Commit
e2b81de
Β·
1 Parent(s): 752f2e3

Initial commit of working version

Browse files
Files changed (8) hide show
  1. README.md +3 -3
  2. app.py +56 -0
  3. gan_utils.py +31 -0
  4. layers.py +273 -0
  5. models.py +246 -0
  6. requirements.txt +4 -0
  7. text_utils.py +31 -0
  8. utils.py +77 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Illustrated Lyrics Generator
3
- emoji: πŸ’»
4
  colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.0.26
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: Illustrated Lyrics Generator
3
+ emoji: 🎢
4
  colorFrom: indigo
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.0.24
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ from text_utils import wrap_text, compute_text_position
5
+ from gan_utils import load_img_generator, generate_img
6
+ from PIL import ImageFont, ImageDraw
7
+ import torch
8
+
9
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ device = "cpu"
11
+
12
+ text_generator = pipeline('text-generation', model='huggingtweets/bestmusiclyric')
13
+
14
+
15
+ def generate_captioned_img(lyrics_prompt, gan_model):
16
+ gan_image = generate_img(device, gan_model)
17
+
18
+ generated_text = text_generator(lyrics_prompt)[0]["generated_text"]
19
+ wrapped_text = wrap_text(generated_text)
20
+
21
+ text_pos = compute_text_position(wrapped_text)
22
+
23
+ # Source: https://stackoverflow.com/a/16377244
24
+ draw = ImageDraw.Draw(gan_image)
25
+ font = ImageFont.truetype("DejaVuSans.ttf", 64)
26
+ draw.text((10, text_pos), text=wrapped_text, fill_color=(255, 255, 255), font=font, stroke_fill=(0, 0, 0),
27
+ stroke_width=5)
28
+
29
+ return gan_image
30
+
31
+
32
+ iface = gr.Interface(fn=generate_captioned_img, inputs=[gr.Textbox(value="Running with the wolves", label="Lyrics prompt", lines=1),
33
+ gr.Radio(value="aurora",
34
+ choices=["painting", "fauvism-still-life", "aurora",
35
+ "universe", "moongate"],
36
+ label="FastGAN model")
37
+ ],
38
+ outputs="image",
39
+ allow_flagging="never",
40
+ title="Illustrated lyrics generator",
41
+ description="Combines song lyrics generation via the [Best Music Lyric Bot]"
42
+ "(https://huggingface.co/huggingtweets/bestmusiclyric) with an artwork randomly "
43
+ "generated by a [FastGAN model](https://huggingface.co/spaces/huggan/FastGan).\n\n"
44
+ "Text and lyrics are generated independently. "
45
+ "If you can implement this idea with images conditioned on the lyrics,"
46
+ " I'd be very interested in seeing that!πŸ€—\n\n"
47
+ "At the bottom of the page, you can click some example inputs to get you started.",
48
+ examples=[["Hey now", "fauvism-still-life"], ["It's gonna take a lot", "universe"],
49
+ ["Running with the wolves", "aurora"], ["His palms are sweaty", "painting"],
50
+ ["I just met you", "moongate"]]
51
+ )
52
+ iface.launch()
53
+
54
+
55
+ #examples=[["Hey now", "painting"], ["It's gonna take a lot", "universe"], ["So close", "aurora"], ["I just met you", "moongate"],
56
+ # ["His palms are sweaty", "aurora"]])
gan_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code adapted from the following sources:
2
+ # https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life
3
+ # https://huggingface.co/spaces/huggan/FastGan/
4
+
5
+ import torch
6
+ from PIL import Image
7
+
8
+ from models import Generator
9
+
10
+
11
+ def load_img_generator(model_name_or_path):
12
+ generator = Generator(in_channels=256, out_channels=3)
13
+ generator = generator.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
14
+ _ = generator.eval()
15
+
16
+ return generator
17
+
18
+
19
+ def _denormalize(input: torch.Tensor) -> torch.Tensor:
20
+ return (input * 127.5) + 127.5
21
+
22
+
23
+ def generate_img(device, gan_model):
24
+ img_generator = load_img_generator("huggan/fastgan-few-shot-"+gan_model)
25
+ noise = torch.zeros(1, 256, 1, 1, device=device).normal_(0.0, 1.0)
26
+ with torch.no_grad():
27
+ gan_images, _ = img_generator(noise)
28
+ gan_image = _denormalize(gan_images.detach()).cpu().squeeze()
29
+ gan_image = gan_image.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
30
+ gan_image = Image.fromarray(gan_image)
31
+ return gan_image
layers.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Source: https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.modules.batchnorm import BatchNorm2d
6
+ from torch.nn.utils import spectral_norm
7
+
8
+
9
+ class SpectralConv2d(nn.Module):
10
+
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__()
13
+ self._conv = spectral_norm(
14
+ nn.Conv2d(*args, **kwargs)
15
+ )
16
+
17
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
18
+ return self._conv(input)
19
+
20
+
21
+ class SpectralConvTranspose2d(nn.Module):
22
+
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__()
25
+ self._conv = spectral_norm(
26
+ nn.ConvTranspose2d(*args, **kwargs)
27
+ )
28
+
29
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
30
+ return self._conv(input)
31
+
32
+
33
+ class Noise(nn.Module):
34
+
35
+ def __init__(self):
36
+ super().__init__()
37
+ self._weight = nn.Parameter(
38
+ torch.zeros(1),
39
+ requires_grad=True,
40
+ )
41
+
42
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
43
+ batch_size, _, height, width = input.shape
44
+ noise = torch.randn(batch_size, 1, height, width, device=input.device)
45
+ return self._weight * noise + input
46
+
47
+
48
+ class InitLayer(nn.Module):
49
+
50
+ def __init__(self, in_channels: int,
51
+ out_channels: int):
52
+ super().__init__()
53
+
54
+ self._layers = nn.Sequential(
55
+ SpectralConvTranspose2d(
56
+ in_channels=in_channels,
57
+ out_channels=out_channels * 2,
58
+ kernel_size=4,
59
+ stride=1,
60
+ padding=0,
61
+ bias=False,
62
+ ),
63
+ nn.BatchNorm2d(num_features=out_channels * 2),
64
+ nn.GLU(dim=1),
65
+ )
66
+
67
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
68
+ return self._layers(input)
69
+
70
+
71
+ class SLEBlock(nn.Module):
72
+
73
+ def __init__(self, in_channels: int,
74
+ out_channels: int):
75
+ super().__init__()
76
+
77
+ self._layers = nn.Sequential(
78
+ nn.AdaptiveAvgPool2d(output_size=4),
79
+ SpectralConv2d(
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ kernel_size=4,
83
+ stride=1,
84
+ padding=0,
85
+ bias=False,
86
+ ),
87
+ nn.SiLU(),
88
+ SpectralConv2d(
89
+ in_channels=out_channels,
90
+ out_channels=out_channels,
91
+ kernel_size=1,
92
+ stride=1,
93
+ padding=0,
94
+ bias=False,
95
+ ),
96
+ nn.Sigmoid(),
97
+ )
98
+
99
+ def forward(self, low_dim: torch.Tensor,
100
+ high_dim: torch.Tensor) -> torch.Tensor:
101
+ return high_dim * self._layers(low_dim)
102
+
103
+
104
+ class UpsampleBlockT1(nn.Module):
105
+
106
+ def __init__(self, in_channels: int,
107
+ out_channels: int):
108
+ super().__init__()
109
+
110
+ self._layers = nn.Sequential(
111
+ nn.Upsample(scale_factor=2, mode='nearest'),
112
+ SpectralConv2d(
113
+ in_channels=in_channels,
114
+ out_channels=out_channels * 2,
115
+ kernel_size=3,
116
+ stride=1,
117
+ padding='same',
118
+ bias=False,
119
+ ),
120
+ nn.BatchNorm2d(num_features=out_channels * 2),
121
+ nn.GLU(dim=1),
122
+ )
123
+
124
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
125
+ return self._layers(input)
126
+
127
+
128
+ class UpsampleBlockT2(nn.Module):
129
+
130
+ def __init__(self, in_channels: int,
131
+ out_channels: int):
132
+ super().__init__()
133
+
134
+ self._layers = nn.Sequential(
135
+ nn.Upsample(scale_factor=2, mode='nearest'),
136
+ SpectralConv2d(
137
+ in_channels=in_channels,
138
+ out_channels=out_channels * 2,
139
+ kernel_size=3,
140
+ stride=1,
141
+ padding='same',
142
+ bias=False,
143
+ ),
144
+ Noise(),
145
+ BatchNorm2d(num_features=out_channels * 2),
146
+ nn.GLU(dim=1),
147
+ SpectralConv2d(
148
+ in_channels=out_channels,
149
+ out_channels=out_channels * 2,
150
+ kernel_size=3,
151
+ stride=1,
152
+ padding='same',
153
+ bias=False,
154
+ ),
155
+ Noise(),
156
+ nn.BatchNorm2d(num_features=out_channels * 2),
157
+ nn.GLU(dim=1),
158
+ )
159
+
160
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
161
+ return self._layers(input)
162
+
163
+
164
+ class DownsampleBlockT1(nn.Module):
165
+
166
+ def __init__(self, in_channels: int,
167
+ out_channels: int):
168
+ super().__init__()
169
+
170
+ self._layers = nn.Sequential(
171
+ SpectralConv2d(
172
+ in_channels=in_channels,
173
+ out_channels=out_channels,
174
+ kernel_size=4,
175
+ stride=2,
176
+ padding=1,
177
+ bias=False,
178
+ ),
179
+ nn.BatchNorm2d(num_features=out_channels),
180
+ nn.LeakyReLU(negative_slope=0.2),
181
+ )
182
+
183
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
184
+ return self._layers(input)
185
+
186
+
187
+ class DownsampleBlockT2(nn.Module):
188
+
189
+ def __init__(self, in_channels: int,
190
+ out_channels: int):
191
+ super().__init__()
192
+
193
+ self._layers_1 = nn.Sequential(
194
+ SpectralConv2d(
195
+ in_channels=in_channels,
196
+ out_channels=out_channels,
197
+ kernel_size=4,
198
+ stride=2,
199
+ padding=1,
200
+ bias=False,
201
+ ),
202
+ nn.BatchNorm2d(num_features=out_channels),
203
+ nn.LeakyReLU(negative_slope=0.2),
204
+ SpectralConv2d(
205
+ in_channels=out_channels,
206
+ out_channels=out_channels,
207
+ kernel_size=3,
208
+ stride=1,
209
+ padding='same',
210
+ bias=False,
211
+ ),
212
+ nn.BatchNorm2d(num_features=out_channels),
213
+ nn.LeakyReLU(negative_slope=0.2),
214
+ )
215
+
216
+ self._layers_2 = nn.Sequential(
217
+ nn.AvgPool2d(
218
+ kernel_size=2,
219
+ stride=2,
220
+ ),
221
+ SpectralConv2d(
222
+ in_channels=in_channels,
223
+ out_channels=out_channels,
224
+ kernel_size=1,
225
+ stride=1,
226
+ padding=0,
227
+ bias=False,
228
+ ),
229
+ nn.BatchNorm2d(num_features=out_channels),
230
+ nn.LeakyReLU(negative_slope=0.2),
231
+ )
232
+
233
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
234
+ t1 = self._layers_1(input)
235
+ t2 = self._layers_2(input)
236
+ return (t1 + t2) / 2
237
+
238
+
239
+ class Decoder(nn.Module):
240
+
241
+ def __init__(self, in_channels: int,
242
+ out_channels: int):
243
+ super().__init__()
244
+
245
+ self._channels = {
246
+ 16: 128,
247
+ 32: 64,
248
+ 64: 64,
249
+ 128: 32,
250
+ 256: 16,
251
+ 512: 8,
252
+ 1024: 4,
253
+ }
254
+
255
+ self._layers = nn.Sequential(
256
+ nn.AdaptiveAvgPool2d(output_size=8),
257
+ UpsampleBlockT1(in_channels=in_channels, out_channels=self._channels[16]),
258
+ UpsampleBlockT1(in_channels=self._channels[16], out_channels=self._channels[32]),
259
+ UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64]),
260
+ UpsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[128]),
261
+ SpectralConv2d(
262
+ in_channels=self._channels[128],
263
+ out_channels=out_channels,
264
+ kernel_size=3,
265
+ stride=1,
266
+ padding='same',
267
+ bias=False,
268
+ ),
269
+ nn.Tanh(),
270
+ )
271
+
272
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
273
+ return self._layers(input)
models.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Source: https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from typing import Any, Tuple, Union
6
+
7
+ from utils import (
8
+ ImageType,
9
+ crop_image_part,
10
+ )
11
+
12
+ from layers import (
13
+ SpectralConv2d,
14
+ InitLayer,
15
+ SLEBlock,
16
+ UpsampleBlockT1,
17
+ UpsampleBlockT2,
18
+ DownsampleBlockT1,
19
+ DownsampleBlockT2,
20
+ Decoder,
21
+ )
22
+
23
+ from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
24
+
25
+
26
+ class Generator(nn.Module, HugGANModelHubMixin):
27
+
28
+ def __init__(self, in_channels: int,
29
+ out_channels: int):
30
+ super().__init__()
31
+
32
+ self._channels = {
33
+ 4: 1024,
34
+ 8: 512,
35
+ 16: 256,
36
+ 32: 128,
37
+ 64: 128,
38
+ 128: 64,
39
+ 256: 32,
40
+ 512: 16,
41
+ 1024: 8,
42
+ }
43
+
44
+ self._init = InitLayer(
45
+ in_channels=in_channels,
46
+ out_channels=self._channels[4],
47
+ )
48
+
49
+ self._upsample_8 = UpsampleBlockT2(in_channels=self._channels[4], out_channels=self._channels[8] )
50
+ self._upsample_16 = UpsampleBlockT1(in_channels=self._channels[8], out_channels=self._channels[16] )
51
+ self._upsample_32 = UpsampleBlockT2(in_channels=self._channels[16], out_channels=self._channels[32] )
52
+ self._upsample_64 = UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64] )
53
+ self._upsample_128 = UpsampleBlockT2(in_channels=self._channels[64], out_channels=self._channels[128] )
54
+ self._upsample_256 = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256] )
55
+ self._upsample_512 = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512] )
56
+ self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])
57
+
58
+ self._sle_64 = SLEBlock(in_channels=self._channels[4], out_channels=self._channels[64] )
59
+ self._sle_128 = SLEBlock(in_channels=self._channels[8], out_channels=self._channels[128])
60
+ self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
61
+ self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])
62
+
63
+ self._out_128 = nn.Sequential(
64
+ SpectralConv2d(
65
+ in_channels=self._channels[128],
66
+ out_channels=out_channels,
67
+ kernel_size=1,
68
+ stride=1,
69
+ padding='same',
70
+ bias=False,
71
+ ),
72
+ nn.Tanh(),
73
+ )
74
+
75
+ self._out_1024 = nn.Sequential(
76
+ SpectralConv2d(
77
+ in_channels=self._channels[1024],
78
+ out_channels=out_channels,
79
+ kernel_size=3,
80
+ stride=1,
81
+ padding='same',
82
+ bias=False,
83
+ ),
84
+ nn.Tanh(),
85
+ )
86
+
87
+ def forward(self, input: torch.Tensor) -> \
88
+ Tuple[torch.Tensor, torch.Tensor]:
89
+ size_4 = self._init(input)
90
+ size_8 = self._upsample_8(size_4)
91
+ size_16 = self._upsample_16(size_8)
92
+ size_32 = self._upsample_32(size_16)
93
+
94
+ size_64 = self._sle_64 (size_4, self._upsample_64 (size_32) )
95
+ size_128 = self._sle_128(size_8, self._upsample_128(size_64) )
96
+ size_256 = self._sle_256(size_16, self._upsample_256(size_128))
97
+ size_512 = self._sle_512(size_32, self._upsample_512(size_256))
98
+
99
+ size_1024 = self._upsample_1024(size_512)
100
+
101
+ out_128 = self._out_128 (size_128)
102
+ out_1024 = self._out_1024(size_1024)
103
+ return out_1024, out_128
104
+
105
+
106
+ class Discriminrator(nn.Module, HugGANModelHubMixin):
107
+
108
+ def __init__(self, in_channels: int):
109
+ super().__init__()
110
+
111
+ self._channels = {
112
+ 4: 1024,
113
+ 8: 512,
114
+ 16: 256,
115
+ 32: 128,
116
+ 64: 128,
117
+ 128: 64,
118
+ 256: 32,
119
+ 512: 16,
120
+ 1024: 8,
121
+ }
122
+
123
+ self._init = nn.Sequential(
124
+ SpectralConv2d(
125
+ in_channels=in_channels,
126
+ out_channels=self._channels[1024],
127
+ kernel_size=4,
128
+ stride=2,
129
+ padding=1,
130
+ bias=False,
131
+ ),
132
+ nn.LeakyReLU(negative_slope=0.2),
133
+ SpectralConv2d(
134
+ in_channels=self._channels[1024],
135
+ out_channels=self._channels[512],
136
+ kernel_size=4,
137
+ stride=2,
138
+ padding=1,
139
+ bias=False,
140
+ ),
141
+ nn.BatchNorm2d(num_features=self._channels[512]),
142
+ nn.LeakyReLU(negative_slope=0.2),
143
+ )
144
+
145
+ self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
146
+ self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
147
+ self._downsample_64 = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64] )
148
+ self._downsample_32 = DownsampleBlockT2(in_channels=self._channels[64], out_channels=self._channels[32] )
149
+ self._downsample_16 = DownsampleBlockT2(in_channels=self._channels[32], out_channels=self._channels[16] )
150
+
151
+ self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
152
+ self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
153
+ self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])
154
+
155
+ self._small_track = nn.Sequential(
156
+ SpectralConv2d(
157
+ in_channels=in_channels,
158
+ out_channels=self._channels[256],
159
+ kernel_size=4,
160
+ stride=2,
161
+ padding=1,
162
+ bias=False,
163
+ ),
164
+ nn.LeakyReLU(negative_slope=0.2),
165
+ DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
166
+ DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64] ),
167
+ DownsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[32] ),
168
+ )
169
+
170
+ self._features_large = nn.Sequential(
171
+ SpectralConv2d(
172
+ in_channels=self._channels[16] ,
173
+ out_channels=self._channels[8],
174
+ kernel_size=1,
175
+ stride=1,
176
+ padding=0,
177
+ bias=False,
178
+ ),
179
+ nn.BatchNorm2d(num_features=self._channels[8]),
180
+ nn.LeakyReLU(negative_slope=0.2),
181
+ SpectralConv2d(
182
+ in_channels=self._channels[8],
183
+ out_channels=1,
184
+ kernel_size=4,
185
+ stride=1,
186
+ padding=0,
187
+ bias=False,
188
+ )
189
+ )
190
+
191
+ self._features_small = nn.Sequential(
192
+ SpectralConv2d(
193
+ in_channels=self._channels[32],
194
+ out_channels=1,
195
+ kernel_size=4,
196
+ stride=1,
197
+ padding=0,
198
+ bias=False,
199
+ ),
200
+ )
201
+
202
+ self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
203
+ self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
204
+ self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)
205
+
206
+ def forward(self, images_1024: torch.Tensor,
207
+ images_128: torch.Tensor,
208
+ image_type: ImageType) -> \
209
+ Union[
210
+ torch.Tensor,
211
+ Tuple[torch.Tensor, Tuple[Any, Any, Any]]
212
+ ]:
213
+ # large track
214
+
215
+ down_512 = self._init(images_1024)
216
+ down_256 = self._downsample_256(down_512)
217
+ down_128 = self._downsample_128(down_256)
218
+
219
+ down_64 = self._downsample_64(down_128)
220
+ down_64 = self._sle_64(down_512, down_64)
221
+
222
+ down_32 = self._downsample_32(down_64)
223
+ down_32 = self._sle_32(down_256, down_32)
224
+
225
+ down_16 = self._downsample_16(down_32)
226
+ down_16 = self._sle_16(down_128, down_16)
227
+
228
+ # small track
229
+
230
+ down_small = self._small_track(images_128)
231
+
232
+ # features
233
+
234
+ features_large = self._features_large(down_16).view(-1)
235
+ features_small = self._features_small(down_small).view(-1)
236
+ features = torch.cat([features_large, features_small], dim=0)
237
+
238
+ # decoder
239
+
240
+ if image_type != ImageType.FAKE:
241
+ dec_large = self._decoder_large(down_16)
242
+ dec_small = self._decoder_small(down_small)
243
+ dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
244
+ return features, (dec_large, dec_small, dec_piece)
245
+
246
+ return features
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ git+https://github.com/huggingface/community-events@main
4
+ gradio
text_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def wrap_text(generated_text):
2
+ wrapping_text = ""
3
+ current_line_length = 0
4
+ print(generated_text)
5
+ if "-" in generated_text:
6
+ quote, author = generated_text.split("-")
7
+ elif "―" in generated_text:
8
+ quote, author = generated_text.split("―")
9
+ else:
10
+ quote = generated_text
11
+ author = None
12
+ for word in quote.split(" "):
13
+ if current_line_length >= 20:
14
+ wrapping_text += f"\n{word} "
15
+ current_line_length = len(word)
16
+ else:
17
+ wrapping_text += f"{word} "
18
+ current_line_length += len(word)
19
+ if author is not None:
20
+ wrapping_text += f"\n- {author}"
21
+ return wrapping_text
22
+
23
+
24
+ def compute_text_position(wrapped_text):
25
+ img_height = 1024
26
+ line_height_in_px = 74 # roughly estimated
27
+ margin_bottom = 100 # align text close to the bottom, leaving this many pixels free
28
+ n_lines = wrapped_text.count("\n") + 1
29
+ text_height = n_lines * line_height_in_px
30
+ text_pos = img_height - margin_bottom - text_height
31
+ return text_pos
utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Source: https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life
2
+ import torch
3
+ import torch.nn as nn
4
+ from enum import Enum
5
+
6
+ import base64
7
+ import json
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ import requests
11
+ import re
12
+
13
+ class ImageType(Enum):
14
+ REAL_UP_L = 0
15
+ REAL_UP_R = 1
16
+ REAL_DOWN_R = 2
17
+ REAL_DOWN_L = 3
18
+ FAKE = 4
19
+
20
+
21
+ def crop_image_part(image: torch.Tensor,
22
+ part: ImageType) -> torch.Tensor:
23
+ size = image.shape[2] // 2
24
+
25
+ if part == ImageType.REAL_UP_L:
26
+ return image[:, :, :size, :size]
27
+
28
+ elif part == ImageType.REAL_UP_R:
29
+ return image[:, :, :size, size:]
30
+
31
+ elif part == ImageType.REAL_DOWN_L:
32
+ return image[:, :, size:, :size]
33
+
34
+ elif part == ImageType.REAL_DOWN_R:
35
+ return image[:, :, size:, size:]
36
+
37
+ else:
38
+ raise ValueError('invalid part')
39
+
40
+
41
+ def init_weights(module: nn.Module):
42
+ if isinstance(module, nn.Conv2d):
43
+ torch.nn.init.normal_(module.weight, 0.0, 0.02)
44
+
45
+ if isinstance(module, nn.BatchNorm2d):
46
+ torch.nn.init.normal_(module.weight, 1.0, 0.02)
47
+ module.bias.data.fill_(0)
48
+
49
+ def load_image_from_local(image_path, image_resize=None):
50
+ image = Image.open(image_path)
51
+
52
+ if isinstance(image_resize, tuple):
53
+ image = image.resize(image_resize)
54
+ return image
55
+
56
+ def load_image_from_url(image_url, rgba_mode=False, image_resize=None, default_image=None):
57
+ try:
58
+ image = Image.open(requests.get(image_url, stream=True).raw)
59
+
60
+ if rgba_mode:
61
+ image = image.convert("RGBA")
62
+
63
+ if isinstance(image_resize, tuple):
64
+ image = image.resize(image_resize)
65
+
66
+ except Exception as e:
67
+ image = None
68
+ if default_image:
69
+ image = load_image_from_local(default_image, image_resize=image_resize)
70
+
71
+ return image
72
+
73
+ def image_to_base64(image_array):
74
+ buffered = BytesIO()
75
+ image_array.save(buffered, format="PNG")
76
+ image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
77
+ return f"data:image/png;base64, {image_b64}"