Wryley1234 commited on
Commit
ec70094
1 Parent(s): c956a18

Coding key

Browse files
Files changed (1) hide show
  1. Unknown +227 -0
Unknown ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from torchvision.datasets.utils import download_url
3
+ from ldm.util import instantiate_from_config
4
+ import torch
5
+ import os
6
+ # todo ?
7
+ from google.colab import files
8
+ from IPython.display import Image as ipyimg
9
+ import ipywidgets as widgets
10
+ from PIL import Image
11
+ from numpy import asarray
12
+ from einops import rearrange, repeat
13
+ import torch, torchvision
14
+ from ldm.models.diffusion.ddim import DDIMSampler
15
+ from ldm.util import ismap
16
+ import time
17
+ from omegaconf import OmegaConf
18
+
19
+
20
+ def download_models(mode):
21
+
22
+ if mode == "superresolution":
23
+ # this is the small bsr light model
24
+ url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
25
+ url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
26
+
27
+ path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml'
28
+ path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt'
29
+
30
+ download_url(url_conf, path_conf)
31
+ download_url(url_ckpt, path_ckpt)
32
+
33
+ path_conf = path_conf + '/?dl=1' # fix it
34
+ path_ckpt = path_ckpt + '/?dl=1' # fix it
35
+ return path_conf, path_ckpt
36
+
37
+ else:
38
+ raise NotImplementedError
39
+
40
+
41
+ def load_model_from_config(config, ckpt):
42
+ print(f"Loading model from {ckpt}")
43
+ pl_sd = torch.load(ckpt, map_location="cpu")
44
+ global_step = pl_sd["global_step"]
45
+ sd = pl_sd["state_dict"]
46
+ model = instantiate_from_config(config.model)
47
+ m, u = model.load_state_dict(sd, strict=False)
48
+ model.cuda()
49
+ model.eval()
50
+ return {"model": model}, global_step
51
+
52
+
53
+ def get_model(mode):
54
+ path_conf, path_ckpt = download_models(mode)
55
+ config = OmegaConf.load(path_conf)
56
+ model, step = load_model_from_config(config, path_ckpt)
57
+ return model
58
+
59
+
60
+ def get_custom_cond(mode):
61
+ dest = "data/example_conditioning"
62
+
63
+ if mode == "superresolution":
64
+ uploaded_img = files.upload()
65
+ filename = next(iter(uploaded_img))
66
+ name, filetype = filename.split(".") # todo assumes just one dot in name !
67
+ os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
68
+
69
+ elif mode == "text_conditional":
70
+ w = widgets.Text(value='A cake with cream!', disabled=True)
71
+ display(w)
72
+
73
+ with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
74
+ f.write(w.value)
75
+
76
+ elif mode == "class_conditional":
77
+ w = widgets.IntSlider(min=0, max=1000)
78
+ display(w)
79
+ with open(f"{dest}/{mode}/custom.txt", 'w') as f:
80
+ f.write(w.value)
81
+
82
+ else:
83
+ raise NotImplementedError(f"cond not implemented for mode{mode}")
84
+
85
+
86
+ def get_cond_options(mode):
87
+ path = "data/example_conditioning"
88
+ path = os.path.join(path, mode)
89
+ onlyfiles = [f for f in sorted(os.listdir(path))]
90
+ return path, onlyfiles
91
+
92
+
93
+ def select_cond_path(mode):
94
+ path = "data/example_conditioning" # todo
95
+ path = os.path.join(path, mode)
96
+ onlyfiles = [f for f in sorted(os.listdir(path))]
97
+
98
+ selected = widgets.RadioButtons(
99
+ options=onlyfiles,
100
+ description='Select conditioning:',
101
+ disabled=False
102
+ )
103
+ display(selected)
104
+ selected_path = os.path.join(path, selected.value)
105
+ return selected_path
106
+
107
+
108
+ def get_cond(mode, selected_path):
109
+ example = dict()
110
+ if mode == "superresolution":
111
+ up_f = 4
112
+ visualize_cond_img(selected_path)
113
+
114
+ c = Image.open(selected_path)
115
+ c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
116
+ c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
117
+ c_up = rearrange(c_up, '1 c h w -> 1 h w c')
118
+ c = rearrange(c, '1 c h w -> 1 h w c')
119
+ c = 2. * c - 1.
120
+
121
+ c = c.to(torch.device("cuda"))
122
+ example["LR_image"] = c
123
+ example["image"] = c_up
124
+
125
+ return example
126
+
127
+
128
+ def visualize_cond_img(path):
129
+ display(ipyimg(filename=path))
130
+
131
+
132
+ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
133
+
134
+ example = get_cond(task, selected_path)
135
+
136
+ save_intermediate_vid = False
137
+ n_runs = 1
138
+ masked = False
139
+ guider = None
140
+ ckwargs = None
141
+ mode = 'ddim'
142
+ ddim_use_x0_pred = False
143
+ temperature = 1.
144
+ eta = 1.
145
+ make_progrow = True
146
+ custom_shape = None
147
+
148
+ height, width = example["image"].shape[1:3]
149
+ split_input = height >= 128 and width >= 128
150
+
151
+ if split_input:
152
+ ks = 128
153
+ stride = 64
154
+ vqf = 4 #
155
+ model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
156
+ "vqf": vqf,
157
+ "patch_distributed_vq": True,
158
+ "tie_braker": False,
159
+ "clip_max_weight": 0.5,
160
+ "clip_min_weight": 0.01,
161
+ "clip_max_tie_weight": 0.5,
162
+ "clip_min_tie_weight": 0.01}
163
+ else:
164
+ if hasattr(model, "split_input_params"):
165
+ delattr(model, "split_input_params")
166
+
167
+ invert_mask = False
168
+
169
+ x_T = None
170
+ for n in range(n_runs):
171
+ if custom_shape is not None:
172
+ x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
173
+ x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
174
+
175
+ logs = make_convolutional_sample(example, model,
176
+ mode=mode, custom_steps=custom_steps,
177
+ eta=eta, swap_mode=False , masked=masked,
178
+ invert_mask=invert_mask, quantize_x0=False,
179
+ custom_schedule=None, decode_interval=10,
180
+ resize_enabled=resize_enabled, custom_shape=custom_shape,
181
+ temperature=temperature, noise_dropout=0.,
182
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
183
+ make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
184
+ )
185
+ return logs
186
+
187
+
188
+ @torch.no_grad()
189
+ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
190
+ mask=None, x0=None, quantize_x0=False, img_callback=None,
191
+ temperature=1., noise_dropout=0., score_corrector=None,
192
+ corrector_kwargs=None, x_T=None, log_every_t=None
193
+ ):
194
+
195
+ ddim = DDIMSampler(model)
196
+ bs = shape[0] # dont know where this comes from but wayne
197
+ shape = shape[1:] # cut batch dim
198
+ print(f"Sampling with eta = {eta}; steps: {steps}")
199
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
200
+ normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
201
+ mask=mask, x0=x0, temperature=temperature, verbose=False,
202
+ score_corrector=score_corrector,
203
+ corrector_kwargs=corrector_kwargs, x_T=x_T)
204
+
205
+ return samples, intermediates
206
+
207
+
208
+ @torch.no_grad()
209
+ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
210
+ invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
211
+ resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
212
+ corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
213
+ log = dict()
214
+
215
+ z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
216
+ return_first_stage_outputs=True,
217
+ force_c_encode=not (hasattr(model, 'split_input_params')
218
+ and model.cond_stage_key == 'coordinates_bbox'),
219
+ return_original_cond=True)
220
+
221
+ log_every_t = 1 if save_intermediate_vid else None
222
+
223
+ if custom_shape is not None:
224
+ z = torch.randn(custom_shape)
225
+ print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
226
+
227
+ z0 = None