johnowhitaker commited on
Commit
d171496
·
1 Parent(s): 36d8a22

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os, glob
3
+ from functools import partial
4
+ import glob
5
+ import torch
6
+ from torch import nn
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
11
+
12
+ class RuleCA(nn.Module):
13
+ def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
14
+ super().__init__()
15
+ # The hard-coded filters:
16
+ self.filters = torch.stack([torch.tensor([[0.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,0.0]]),
17
+ torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]),
18
+ torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]).T,
19
+ torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])]).to(device)
20
+ self.chn = 4
21
+ self.rule_channels = rule_channels
22
+ self.w1 = nn.Conv2d(4*4+rule_channels, hidden_n, 1).to(device)
23
+ self.relu = nn.ReLU()
24
+ self.w2 = nn.Conv2d(hidden_n, 4, 1, bias=False).to(device)
25
+ if zero_w2:
26
+ self.w2.weight.data.zero_()
27
+ self.device = device
28
+
29
+ def perchannel_conv(self, x, filters):
30
+ '''filters: [filter_n, h, w]'''
31
+ b, ch, h, w = x.shape
32
+ y = x.reshape(b*ch, 1, h, w)
33
+ y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
34
+ y = torch.nn.functional.conv2d(y, filters[:,None])
35
+ return y.reshape(b, -1, h, w)
36
+
37
+ def forward(self, x, rule=0, update_rate=0.5):
38
+ b, ch, xsz, ysz = x.shape
39
+ rule_grid = torch.zeros(b, self.rule_channels, xsz, ysz).to(self.device)
40
+ rule_grid[:,rule] = 1
41
+ y = self.perchannel_conv(x, self.filters) # Apply the filters
42
+ y = torch.cat([y, rule_grid], dim=1)
43
+ y = self.w2(self.relu(self.w1(y))) # pass the result through out 'brain'
44
+ b, c, h, w = y.shape
45
+ update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
46
+ return x+y*update_mask
47
+
48
+ def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
49
+ y = self.perchannel_conv(x, self.filters) # Apply the filters
50
+ y = torch.cat([y, rule_grid], dim=1)
51
+ y = self.w2(self.relu(self.w1(y))) # pass the result through out 'brain'
52
+ b, c, h, w = y.shape
53
+ update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
54
+ return x+y*update_mask
55
+
56
+ def to_rgb(self, x):
57
+ # TODO: rename this to_rgb & explain
58
+ return x[...,:3,:,:]+0.5
59
+
60
+ def seed(self, n, sz=128):
61
+ """Initializes n 'grids', size sz. In this case all 0s."""
62
+ return torch.zeros(n, self.chn, sz, sz).to(self.device)
63
+
64
+ def to_frames(video_file):
65
+ os.system('rm -r guide_frames;mkdir guide_frames')
66
+ os.system(f"ffmpeg -i {video_file} guide_frames/%04d.jpg")
67
+
68
+ def update(preset, enhance, video_file):
69
+
70
+ # Load presets
71
+ ca = RuleCA(hidden_n=32, rule_channels=3)
72
+ ca_fn = ''
73
+ if preset == 'Glowing Crystals':
74
+ ca_fn = 'glowing_crystals.pt'
75
+ elif preset == 'Rainbow Diamonds':
76
+ ca_fn = 'rainbow_diamonds.pt'
77
+ elif preset == 'Dark Diamonds':
78
+ ca_fn = 'dark_diamonds.pt'
79
+ elif preset == 'Dragon Scales':
80
+ ca = RuleCA(hidden_n=16, rule_channels=3)
81
+ ca_fn = 'dragon_scales.pt'
82
+
83
+ ca.load_state_dict(torch.load(ca_fn, map_location=device))
84
+
85
+ # Get video frames
86
+ to_frames(video_file)
87
+
88
+ size=(426, 240)
89
+ vid_size = Image.open(f'guide_frames/0001.jpg').size
90
+ if vid_size[0]>vid_size[1]:
91
+ size = (256, int(256*(vid_size[1]/vid_size[0])))
92
+ else:
93
+ size = (int(256*(vid_size[0]/vid_size[1])), 256)
94
+
95
+ # Starting grid
96
+ x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)
97
+ os.system("rm -r steps;mkdir steps")
98
+ for i in range(2*len(glob.glob('guide_frames/*.jpg'))-1):
99
+ # load frame
100
+ im = Image.open(f'guide_frames/{i//2+1:04}.jpg').resize(size)
101
+
102
+ # make rule grid
103
+ rule_grid = torch.tensor(np.array(im)/255).permute(2, 0, 1).unsqueeze(0).to(ca.device)
104
+ if enhance:
105
+ rule_grid = rule_grid * 2 - 0.3 # Add * 2 - 0.3 to 'enhance' an effect
106
+
107
+ # Apply the updates
108
+ with torch.no_grad():
109
+ x = ca.forward_w_rule_grid(x, rule_grid.float())
110
+ if i%2==0:
111
+ img = ca.to_rgb(x).detach().cpu().clip(0, 1).squeeze().permute(1, 2, 0)
112
+ img = Image.fromarray(np.array(img*255).astype(np.uint8))
113
+ img.save(f'steps/{i//2:05}.jpeg')
114
+
115
+ # Write output video from saved frames
116
+ os.system("ffmpeg -y -v 0 -framerate 24 -i steps/%05d.jpeg video.mp4")
117
+ return 'video.mp4'
118
+
119
+
120
+ demo = gr.Blocks()
121
+
122
+ with demo:
123
+ gr.Markdown("Start typing below and then click **Run** to see the output.")
124
+ with gr.Row():
125
+ preset = gr.Dropdown(['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'], label='Preset')
126
+ enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
127
+ with gr.Row():
128
+ inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
129
+ out = gr.Video(label="Output")
130
+ btn = gr.Button("Run")
131
+ btn.click(fn=update, inputs=[preset, enhance, inp], outputs=out)
132
+
133
+ demo.launch(enable_queue=True)