paulai committed on
Commit
365a183
·
verified ·
1 Parent(s): 091d7d2
Files changed (1) hide show
  1. scripts/evaluation/funcs.py +196 -194
scripts/evaluation/funcs.py CHANGED
@@ -1,194 +1,196 @@
1
- import os, sys, glob
2
- import numpy as np
3
- from collections import OrderedDict
4
- from decord import VideoReader, cpu
5
- import cv2
6
-
7
- import torch
8
- import torchvision
9
- sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
10
- from lvdm.models.samplers.ddim import DDIMSampler
11
-
12
-
13
def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,
                        cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
    '''
    Run DDIM sampling ``n_samples`` times and decode every result.

    Args:
        model: diffusion model. This function reads ``model.uncond_type`` and calls
            ``model.get_learned_conditioning`` / ``model.decode_first_stage_2DAE``;
            when the model has an ``embedder`` attribute it also uses
            ``model.get_image_embeds`` and ``model.device``.
        cond: conditioning, either a dict holding a ``"c_crossattn"`` list or a bare embedding.
        noise_shape: latent shape; index 0 is used as the batch size, index 2 as the
            temporal length and ``noise_shape[1:]`` is forwarded as the sample shape.
        n_samples: number of independent sampling runs.
        ddim_steps: forwarded to the sampler as ``S``.
        ddim_eta: forwarded to the sampler as ``eta``.
        cfg_scale: guidance scale; unconditional conditioning is only built when it is not 1.0.
        temporal_cfg_scale: forwarded as ``conditional_guidance_scale_temporal``.
        **kwargs: forwarded to ``DDIMSampler.sample``; ``clean_cond`` is always forced to True.

    Returns:
        The decoded runs stacked along a new dim 1
        (presumably batch, samples, c, t, h, w -- per the original author's note).

    Raises:
        ValueError: ``cfg_scale != 1.0`` and ``model.uncond_type`` is neither
            ``"empty_seq"`` nor ``"zero_embed"``.
    '''
    ddim_sampler = DDIMSampler(model)
    uncond_type = model.uncond_type
    batch_size = noise_shape[0]

    ## construct unconditional guidance
    uc = None
    if cfg_scale != 1.0:
        if uncond_type == "empty_seq":
            prompts = batch_size * [""]
            uc_emb = model.get_learned_conditioning(prompts)
        elif uncond_type == "zero_embed":
            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
            uc_emb = torch.zeros_like(c_emb)
        else:
            ## previously this fell through and crashed later on an unbound `uc_emb`
            raise ValueError(f"unsupported uncond_type: {uncond_type!r}")

        ## process image embedding token
        if hasattr(model, 'embedder'):
            uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
            ## img: b c h w >> b l c
            uc_img = model.get_image_embeds(uc_img)
            uc_emb = torch.cat([uc_emb, uc_img], dim=1)

        if isinstance(cond, dict):
            ## keep every other conditioning entry, only swap the cross-attention one
            uc = dict(cond)
            uc['c_crossattn'] = [uc_emb]
        else:
            uc = uc_emb

    ## `clean_cond` is forced on for every run, overriding any caller-supplied value
    sampler_kwargs = dict(kwargs, clean_cond=True)

    batch_variants = []
    for _ in range(n_samples):
        samples, _ = ddim_sampler.sample(S=ddim_steps,
                                         conditioning=cond,
                                         batch_size=noise_shape[0],
                                         shape=noise_shape[1:],
                                         verbose=False,
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta,
                                         temporal_length=noise_shape[2],
                                         conditional_guidance_scale_temporal=temporal_cfg_scale,
                                         x_T=None,
                                         **sampler_kwargs)
        ## reconstruct from latent to pixel space
        batch_variants.append(model.decode_first_stage_2DAE(samples))
    ## batch, <samples>, c, t, h, w
    return torch.stack(batch_variants, dim=1)
69
-
70
-
71
def get_filelist(data_dir, ext='*'):
    '''
    Return the sorted paths inside ``data_dir`` whose names match ``*.<ext>``.

    ``ext`` is spliced into a glob pattern, so glob syntax in it is honoured;
    the default ``'*'`` matches any extension.
    '''
    pattern = os.path.join(data_dir, '*.%s' % ext)
    return sorted(glob.glob(pattern))
75
-
76
def get_dirlist(path):
    '''
    Return the sorted paths of the immediate sub-directories of ``path``.

    A ``path`` that does not exist yields an empty list.
    '''
    if not os.path.exists(path):
        return []
    candidates = (os.path.join(path, entry) for entry in os.listdir(path))
    return sorted(sub for sub in candidates if os.path.isdir(sub))
86
-
87
-
88
def load_model_checkpoint(model, ckpt):
    '''
    Load the weights stored at ``ckpt`` into ``model`` (strict matching) and return ``model``.

    Three checkpoint layouts are tried, in this order:
      1. a dict with a ``"module"`` entry (the original code labels this "deepspeed");
         the first 16 characters of every key are dropped before loading
         (presumably a ``_forward_module.`` prefix -- TODO confirm);
      2. a dict with a ``"state_dict"`` entry;
      3. a bare state dict.
    If layout 1 is present but fails to load, layouts 2/3 are still attempted,
    as the original best-effort code did.
    '''
    ## NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt, map_location="cpu")

    loaded = False
    if "module" in state_dict:
        try:
            ## deepspeed
            new_pl_sd = OrderedDict(
                (key[16:], value) for key, value in state_dict["module"].items()
            )
            model.load_state_dict(new_pl_sd, strict=True)
            loaded = True
        except Exception:
            ## best-effort fallback; unlike a bare `except:` this lets
            ## KeyboardInterrupt / SystemExit propagate
            loaded = False

    if not loaded:
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        model.load_state_dict(state_dict, strict=True)

    print('>>> model checkpoint loaded.')
    return model
105
-
106
-
107
def load_prompts(prompt_file):
    '''
    Read ``prompt_file`` and return its non-blank lines, each stripped of
    leading/trailing whitespace, in file order.
    '''
    ## `with` guarantees the handle is closed even if reading raises
    with open(prompt_file, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return [prompt for prompt in stripped_lines if prompt]
116
-
117
-
118
def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
    '''
    Load several videos as one float tensor scaled to [-1, 1].

    Notice about some special cases:
    1. video_frames=-1 means to take all the frames (with fs=1)
    2. when the total video frames is less than required, padding strategy will be used (repeated last frame)

    Args:
        filepath_list: video paths, each opened with decord's ``VideoReader``.
        frame_stride: positive step between sampled frame indices.
        video_size: ``(height, width)`` passed to the reader.
        video_frames: frames wanted per video; any negative value means "all frames".

    Returns:
        ``torch.stack`` of the per-video tensors along dim 0. Each per-video tensor is
        the reader output permuted from [t,h,w,c] to [c,t,h,w]; stacking requires all
        videos to end up with the same frame count.
    '''
    assert frame_stride > 0, "valid frame stride should be a positive interge!"
    take_all = video_frames < 0
    if take_all:
        ## all frames are collected: fs=1 is a must. Resolve this *before* counting the
        ## valid frames; the old code counted with the caller's stride first, so the
        ## first video was truncated and padded instead of being read in full.
        frame_stride = 1

    batch_tensor = []
    for filepath in filepath_list:
        vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
        total_frames = len(vidreader)
        max_valid_frames = (total_frames - 1) // frame_stride + 1
        required_frames = total_frames if take_all else video_frames
        query_frames = min(required_frames, max_valid_frames)
        frame_indices = [frame_stride * i for i in range(query_frames)]

        ## [t,h,w,c] -> [c,t,h,w]
        frames = vidreader.get_batch(frame_indices)
        frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
        frame_tensor = (frame_tensor / 255. - 0.5) * 2

        padding_num = required_frames - query_frames
        if padding_num > 0:
            ## too short: repeat the last frame along the time axis
            last_frame = frame_tensor[:, -1:, :, :]
            frame_tensor = torch.cat([frame_tensor, last_frame.expand(-1, padding_num, -1, -1)], dim=1)
            print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
        batch_tensor.append(frame_tensor)

    return torch.stack(batch_tensor, dim=0)
155
-
156
- from PIL import Image
157
def load_image_batch(filepath_list, image_size=(256,256)):
    '''
    Load one image per path as a float tensor scaled to [-1, 1].

    Supported inputs (extension matched case-insensitively):
      * ``.mp4``: the first frame, read through decord at the requested size;
      * ``.png`` / ``.jpg`` / ``.jpeg``: opened with PIL, converted to RGB and
        resized with OpenCV bilinear interpolation.

    Args:
        filepath_list: paths to load.
        image_size: ``(height, width)`` of every output image.

    Returns:
        ``torch.stack`` of the per-file channel-first tensors along dim 0.

    Raises:
        NotImplementedError: a path has an unsupported extension.
    '''
    batch_tensor = []
    for filepath in filepath_list:
        ext = os.path.splitext(os.path.basename(filepath))[1]
        kind = ext.lower()  # accept ".PNG", ".JPG", ... as well
        if kind == '.mp4':
            vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
            frame = vidreader.get_batch([0])
            img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
        elif kind in ('.png', '.jpg', '.jpeg'):
            ## the context manager closes the underlying file once pixels are copied out
            with Image.open(filepath) as img:
                rgb_img = np.array(img.convert("RGB"), np.float32)
            ## cv2.resize takes (width, height)
            rgb_img = cv2.resize(rgb_img, (image_size[1], image_size[0]), interpolation=cv2.INTER_LINEAR)
            img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
        else:
            message = f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]'
            print(message)
            raise NotImplementedError(message)
        ## [0, 255] -> [-1, 1]
        img_tensor = (img_tensor / 255. - 0.5) * 2
        batch_tensor.append(img_tensor)
    return torch.stack(batch_tensor, dim=0)
179
-
180
-
181
def save_videos(batch_tensors, savedir, filenames, fps=10):
    '''
    Write one ``.mp4`` per batch item, with that item's samples tiled in a single row.

    Args:
        batch_tensors: tensor indexed as b,samples,c,t,h,w (per the original note);
            values are clamped to [-1, 1] before conversion to uint8.
        savedir: output directory; created if missing.
        filenames: output base names (without extension), one per batch item.
        fps: frame rate passed to ``torchvision.io.write_video``.
    '''
    if savedir:
        ## avoid a failure inside the encoder when the target folder does not exist yet
        os.makedirs(savedir, exist_ok=True)

    # b,samples,c,t,h,w
    n_samples = int(batch_tensors.shape[1])
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
        ## one grid per time step; nrow=n_samples places all samples side by side
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=n_samples) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim: [t, c, H, W]
        grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # -> [t, H, W, c]
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
194
-
 
 
 
1
+ import os, sys, glob
2
+ import numpy as np
3
+ from collections import OrderedDict
4
+ from decord import VideoReader, cpu
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
10
+ from lvdm.models.samplers.ddim import DDIMSampler
11
+
12
+
13
def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,
                        cfg_scale=1.0, temporal_cfg_scale=None, neg_prompts=None, **kwargs):
    '''
    Run DDIM sampling ``n_samples`` times and decode every result.

    Args:
        model: diffusion model. This function reads ``model.uncond_type`` and calls
            ``model.get_learned_conditioning`` / ``model.decode_first_stage_2DAE``;
            when the model has an ``embedder`` attribute it also uses
            ``model.get_image_embeds`` and ``model.device``.
        cond: conditioning, either a dict holding a ``"c_crossattn"`` list or a bare embedding.
        noise_shape: latent shape; index 0 is used as the batch size, index 2 as the
            temporal length and ``noise_shape[1:]`` is forwarded as the sample shape.
        n_samples: number of independent sampling runs.
        ddim_steps: forwarded to the sampler as ``S``.
        ddim_eta: forwarded to the sampler as ``eta``.
        cfg_scale: guidance scale; unconditional conditioning is only built when it is not 1.0.
        temporal_cfg_scale: forwarded as ``conditional_guidance_scale_temporal``.
        neg_prompts: optional negative prompt. Only used when ``uncond_type`` is
            ``"empty_seq"``, where it replaces the empty string and is repeated
            ``batch_size`` times (so presumably a single string -- TODO confirm).
        **kwargs: forwarded to ``DDIMSampler.sample``; ``clean_cond`` is always forced to True.

    Returns:
        The decoded runs stacked along a new dim 1
        (presumably batch, samples, c, t, h, w -- per the original author's note).

    Raises:
        ValueError: ``cfg_scale != 1.0`` and ``model.uncond_type`` is neither
            ``"empty_seq"`` nor ``"zero_embed"``.
    '''
    ddim_sampler = DDIMSampler(model)
    uncond_type = model.uncond_type
    batch_size = noise_shape[0]

    ## construct unconditional guidance
    uc = None
    if cfg_scale != 1.0:
        if uncond_type == "empty_seq":
            uncond_text = "" if neg_prompts is None else neg_prompts
            prompts = batch_size * [uncond_text]
            uc_emb = model.get_learned_conditioning(prompts)
        elif uncond_type == "zero_embed":
            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
            uc_emb = torch.zeros_like(c_emb)
        else:
            ## previously this fell through and crashed later on an unbound `uc_emb`
            raise ValueError(f"unsupported uncond_type: {uncond_type!r}")

        ## process image embedding token
        if hasattr(model, 'embedder'):
            uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
            ## img: b c h w >> b l c
            uc_img = model.get_image_embeds(uc_img)
            uc_emb = torch.cat([uc_emb, uc_img], dim=1)

        if isinstance(cond, dict):
            ## keep every other conditioning entry, only swap the cross-attention one
            uc = dict(cond)
            uc['c_crossattn'] = [uc_emb]
        else:
            uc = uc_emb

    ## `clean_cond` is forced on for every run, overriding any caller-supplied value
    sampler_kwargs = dict(kwargs, clean_cond=True)

    batch_variants = []
    for _ in range(n_samples):
        samples, _ = ddim_sampler.sample(S=ddim_steps,
                                         conditioning=cond,
                                         batch_size=noise_shape[0],
                                         shape=noise_shape[1:],
                                         verbose=False,
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta,
                                         temporal_length=noise_shape[2],
                                         conditional_guidance_scale_temporal=temporal_cfg_scale,
                                         x_T=None,
                                         **sampler_kwargs)
        ## reconstruct from latent to pixel space
        batch_variants.append(model.decode_first_stage_2DAE(samples))
    ## batch, <samples>, c, t, h, w
    return torch.stack(batch_variants, dim=1)
71
+
72
+
73
def get_filelist(data_dir, ext='*'):
    '''
    Return the sorted paths inside ``data_dir`` whose names match ``*.<ext>``.

    ``ext`` is spliced into a glob pattern, so glob syntax in it is honoured;
    the default ``'*'`` matches any extension.
    '''
    pattern = os.path.join(data_dir, '*.%s' % ext)
    return sorted(glob.glob(pattern))
77
+
78
def get_dirlist(path):
    '''
    Return the sorted paths of the immediate sub-directories of ``path``.

    A ``path`` that does not exist yields an empty list.
    '''
    if not os.path.exists(path):
        return []
    candidates = (os.path.join(path, entry) for entry in os.listdir(path))
    return sorted(sub for sub in candidates if os.path.isdir(sub))
88
+
89
+
90
def load_model_checkpoint(model, ckpt):
    '''
    Load the weights stored at ``ckpt`` into ``model`` (strict matching) and return ``model``.

    Three checkpoint layouts are tried, in this order:
      1. a dict with a ``"module"`` entry (the original code labels this "deepspeed");
         the first 16 characters of every key are dropped before loading
         (presumably a ``_forward_module.`` prefix -- TODO confirm);
      2. a dict with a ``"state_dict"`` entry;
      3. a bare state dict.
    If layout 1 is present but fails to load, layouts 2/3 are still attempted,
    as the original best-effort code did.
    '''
    ## NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt, map_location="cpu")

    loaded = False
    if "module" in state_dict:
        try:
            ## deepspeed
            new_pl_sd = OrderedDict(
                (key[16:], value) for key, value in state_dict["module"].items()
            )
            model.load_state_dict(new_pl_sd, strict=True)
            loaded = True
        except Exception:
            ## best-effort fallback; unlike a bare `except:` this lets
            ## KeyboardInterrupt / SystemExit propagate
            loaded = False

    if not loaded:
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        model.load_state_dict(state_dict, strict=True)

    print('>>> model checkpoint loaded.')
    return model
107
+
108
+
109
def load_prompts(prompt_file):
    '''
    Read ``prompt_file`` and return its non-blank lines, each stripped of
    leading/trailing whitespace, in file order.
    '''
    ## `with` guarantees the handle is closed even if reading raises
    with open(prompt_file, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return [prompt for prompt in stripped_lines if prompt]
118
+
119
+
120
def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
    '''
    Load several videos as one float tensor scaled to [-1, 1].

    Notice about some special cases:
    1. video_frames=-1 means to take all the frames (with fs=1)
    2. when the total video frames is less than required, padding strategy will be used (repeated last frame)

    Args:
        filepath_list: video paths, each opened with decord's ``VideoReader``.
        frame_stride: positive step between sampled frame indices.
        video_size: ``(height, width)`` passed to the reader.
        video_frames: frames wanted per video; any negative value means "all frames".

    Returns:
        ``torch.stack`` of the per-video tensors along dim 0. Each per-video tensor is
        the reader output permuted from [t,h,w,c] to [c,t,h,w]; stacking requires all
        videos to end up with the same frame count.
    '''
    assert frame_stride > 0, "valid frame stride should be a positive interge!"
    take_all = video_frames < 0
    if take_all:
        ## all frames are collected: fs=1 is a must. Resolve this *before* counting the
        ## valid frames; the old code counted with the caller's stride first, so the
        ## first video was truncated and padded instead of being read in full.
        frame_stride = 1

    batch_tensor = []
    for filepath in filepath_list:
        vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
        total_frames = len(vidreader)
        max_valid_frames = (total_frames - 1) // frame_stride + 1
        required_frames = total_frames if take_all else video_frames
        query_frames = min(required_frames, max_valid_frames)
        frame_indices = [frame_stride * i for i in range(query_frames)]

        ## [t,h,w,c] -> [c,t,h,w]
        frames = vidreader.get_batch(frame_indices)
        frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
        frame_tensor = (frame_tensor / 255. - 0.5) * 2

        padding_num = required_frames - query_frames
        if padding_num > 0:
            ## too short: repeat the last frame along the time axis
            last_frame = frame_tensor[:, -1:, :, :]
            frame_tensor = torch.cat([frame_tensor, last_frame.expand(-1, padding_num, -1, -1)], dim=1)
            print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
        batch_tensor.append(frame_tensor)

    return torch.stack(batch_tensor, dim=0)
157
+
158
+ from PIL import Image
159
def load_image_batch(filepath_list, image_size=(256,256)):
    '''
    Load one image per path as a float tensor scaled to [-1, 1].

    Supported inputs (extension matched case-insensitively):
      * ``.mp4``: the first frame, read through decord at the requested size;
      * ``.png`` / ``.jpg`` / ``.jpeg``: opened with PIL, converted to RGB and
        resized with OpenCV bilinear interpolation.

    Args:
        filepath_list: paths to load.
        image_size: ``(height, width)`` of every output image.

    Returns:
        ``torch.stack`` of the per-file channel-first tensors along dim 0.

    Raises:
        NotImplementedError: a path has an unsupported extension.
    '''
    batch_tensor = []
    for filepath in filepath_list:
        ext = os.path.splitext(os.path.basename(filepath))[1]
        kind = ext.lower()  # accept ".PNG", ".JPG", ... as well
        if kind == '.mp4':
            vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
            frame = vidreader.get_batch([0])
            img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
        elif kind in ('.png', '.jpg', '.jpeg'):
            ## the context manager closes the underlying file once pixels are copied out
            with Image.open(filepath) as img:
                rgb_img = np.array(img.convert("RGB"), np.float32)
            ## cv2.resize takes (width, height)
            rgb_img = cv2.resize(rgb_img, (image_size[1], image_size[0]), interpolation=cv2.INTER_LINEAR)
            img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
        else:
            message = f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]'
            print(message)
            raise NotImplementedError(message)
        ## [0, 255] -> [-1, 1]
        img_tensor = (img_tensor / 255. - 0.5) * 2
        batch_tensor.append(img_tensor)
    return torch.stack(batch_tensor, dim=0)
181
+
182
+
183
def save_videos(batch_tensors, savedir, filenames, fps=10):
    '''
    Write one ``.mp4`` per batch item, with that item's samples tiled in a single row.

    Args:
        batch_tensors: tensor indexed as b,samples,c,t,h,w (per the original note);
            values are clamped to [-1, 1] before conversion to uint8.
        savedir: output directory; created if missing.
        filenames: output base names (without extension), one per batch item.
        fps: frame rate passed to ``torchvision.io.write_video``.
    '''
    if savedir:
        ## avoid a failure inside the encoder when the target folder does not exist yet
        os.makedirs(savedir, exist_ok=True)

    # b,samples,c,t,h,w
    n_samples = int(batch_tensors.shape[1])
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
        ## one grid per time step; nrow=n_samples places all samples side by side
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=n_samples) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim: [t, c, H, W]
        grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # -> [t, H, W, c]
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
196
+