Spaces:
Running
on
A10G
Running
on
A10G
File size: 1,539 Bytes
eb9a9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
def flatten_grid(x, grid_size=(2, 2)):
    '''
    Lay the row strips of a tiled grid side by side in one long row.

    x: B x C x H x W
    grid_size: (rows, cols) of the grid stored in x. Only the row count is
        used: whole row strips are moved, they are never cut into columns.

    Returns a B x C x (H // rows) x (rows * W) tensor in which the row strips
    of x are concatenated left to right, top strip first.

    Raises ValueError if the row count is not positive or does not divide H.
    '''
    # Unpacking keeps the original strictness: x must be 4-D and grid_size
    # must have exactly two entries.
    _, _, height, _ = x.size()
    rows, _ = grid_size
    # torch.split silently emits a different number of strips (or a ragged
    # last strip) when H is not a multiple of rows, so reject that up front.
    if rows <= 0 or height % rows != 0:
        raise ValueError(
            f"grid rows ({rows}) must be positive and divide the tensor height ({height})"
        )
    strips = torch.split(x, height // rows, dim=2)
    return torch.cat(strips, dim=-1)
def unflatten_grid(x, grid_size=(2, 2)):
    '''
    Fold one long row of tiles back into a grid of stacked row strips.

    x: B x C x H x W, where W holds all row strips of the grid side by side
    grid_size: (rows, cols) of the grid to rebuild. Only the row count is
        used: the width is cut into `rows` equal strips.

    Returns a B x C x (rows * H) x (W // rows) tensor with the strips stacked
    top to bottom in their left-to-right order. For a tensor produced by
    torch.cat(torch.split(g, g_height // rows, dim=2), dim=-1) this restores g.

    Raises ValueError if the row count is not positive or does not divide W.
    '''
    # Unpacking keeps the original strictness: x must be 4-D and grid_size
    # must have exactly two entries.
    _, _, _, width = x.size()
    rows, _ = grid_size
    if rows <= 0 or width % rows != 0:
        raise ValueError(
            f"grid rows ({rows}) must be positive and divide the tensor width ({width})"
        )
    # The strip width is W // rows. Dividing by the column count instead is
    # only equivalent for square grids and scrambles non-square ones.
    strips = torch.split(x, width // rows, dim=3)
    return torch.cat(strips, dim=-2)
def prepare_key_grid_latents(latents_video, latent_grid_size=(2, 2), key_grid_size=(3, 3), rand_indices=None):
    '''
    Assemble a key-frame grid from selected tiles of a stack of latent grids.

    latents_video: T x C x H x W; each of the T entries is flattened on its
        own with flatten_grid(..., latent_grid_size).
    latent_grid_size: (rows, cols) of every grid in latents_video.
    key_grid_size: grid size handed to unflatten_grid for the result.
    rand_indices: iterable of tile indices. Index i selects the i-th tile,
        of width W // latent_grid_size[1], in the row formed by concatenating
        all flattened grids in order. Required despite the None default.

    Returns (keyframe_grid, rand_indices), with rand_indices passed back
    unchanged.

    Raises TypeError if rand_indices is None.
    '''
    if rand_indices is None:
        # Fail before doing any tensor work; the comprehension below would
        # otherwise die with an opaque "'NoneType' object is not iterable".
        raise TypeError("rand_indices must be an iterable of tile indices, got None")
    tile_w = latents_video.size(-1) // latent_grid_size[1]
    # One long row holding every tile of every grid, grid by grid.
    long_flatten = torch.cat(
        [flatten_grid(frame.unsqueeze(0), latent_grid_size) for frame in latents_video],
        dim=-1,
    )
    selected = [
        long_flatten[:, :, :, ind * tile_w:(ind + 1) * tile_w]
        for ind in rand_indices
    ]
    keyframe_grid = unflatten_grid(torch.cat(selected, dim=-1), key_grid_size)
    return keyframe_grid, rand_indices
def pil_grid_to_frames(pil_grid, grid_size=(2, 2)):
    '''
    Cut a grid image into its individual tiles.

    pil_grid: object exposing `.size` as (width, height) and
        `.crop((left, upper, right, lower))` (presumably a PIL image,
        judging by the name).
    grid_size: (rows, cols) of the grid.

    Returns the list of cropped tiles in row-major order (left to right,
    then top to bottom). Tile size is the floor of the image size divided by
    the grid size, so leftover pixels at the right/bottom edge are dropped.
    '''
    rows, cols = grid_size[0], grid_size[1]
    total_w, total_h = pil_grid.size
    tile_w = total_w // cols
    tile_h = total_h // rows
    return [
        pil_grid.crop((c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h))
        for r in range(rows)
        for c in range(cols)
    ]
if __name__ == '__main__':
    # Scratch tensor created only when the file is run as a script: a 1 x 3
    # float tensor of random integers drawn from [0, 5). Nothing visible in
    # this file reads it afterwards.
    a = torch.randint(low=0, high=5, size=(1, 3), dtype=torch.float)
|