import os
from glob import glob

import joblib
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from torch.utils.data import DataLoader
from tqdm import tqdm

from train import instantiate_from_config
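

# FeatClusterStage is a conditioning ("cond") stage without learnable parameters:
# it quantises conditioning features with a k-means clusterer that is either
# loaded from `cached_kmeans_path` or fitted on the dataset described by
# `feats_dataset_config`.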
class FeatClusterStage(object):

    def __init__(self, num_clusters=None, cached_kmeans_path=None, feats_dataset_config=None, num_workers=None):
        if cached_kmeans_path is not None and os.path.exists(cached_kmeans_path):
            print(f'Precalculated Clusterer already exists, loading from {cached_kmeans_path}')
            self.clusterer = joblib.load(cached_kmeans_path)
        elif feats_dataset_config is not None:
            self.clusterer = self.load_or_precalculate_kmeans(num_clusters, feats_dataset_config, num_workers)
        else:
            raise Exception('Neither `feats_dataset_config` nor `cached_kmeans_path` is defined')

    def eval(self):
        # No parameters to switch to eval mode; kept for interface compatibility.
        return self

    def encode(self, c):
        # c_quant: cluster centers, c_ind: cluster index
        B, D, T = c.shape
        # (B*T, D) <- (B, T, D) <- (B, D, T)
        c_flat = c.permute(0, 2, 1).view(B*T, D).cpu().numpy()
        c_ind = self.clusterer.predict(c_flat)
        c_quant = self.clusterer.cluster_centers_[c_ind]
        c_ind = torch.from_numpy(c_ind).to(c.device)
        c_quant = torch.from_numpy(c_quant).to(c.device)
        c_ind = c_ind.long().unsqueeze(-1)
        c_quant = c_quant.view(B, T, D).permute(0, 2, 1)
        info = None, None, c_ind
        # (B, D, T), (), ((), (768, 1024), (768, 1))
        return c_quant, None, info
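
    # Shape example for `encode` (illustrative values, assumed): with B=4 clips,
    # T=192 feature frames and D=1024-dim features, `c` is (4, 1024, 192),
    # `c_flat` is (768, 1024), `c_ind` ends up as (768, 1) and `c_quant` is
    # mapped back to (4, 1024, 192), matching the shape note above.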

    def decode(self, c):
        # Identity: the quantised features are used as-is, nothing to decode.
        return c

    def get_input(self, batch, k):
        x = batch[k]
        # (B, D, T) <- (B, T, D)
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        return x.float()

    def load_or_precalculate_kmeans(self, num_clusters, dataset_cfg, num_workers):
        print(f'Calculating clustering K={num_clusters}')
        batch_size = 64
        dataset_name = dataset_cfg.target.split('.')[-1]
        cached_path = os.path.join('./specvqgan/modules/misc/', f'kmeans_K{num_clusters}_{dataset_name}.sklearn')
        feat_depth = dataset_cfg.params.condition_dataset_cfg.feat_depth
        feat_crop_len = dataset_cfg.params.condition_dataset_cfg.feat_crop_len
        feat_loading_dset = instantiate_from_config(dataset_cfg)
        feat_loading_dset = DataLoader(feat_loading_dset, batch_size, num_workers=num_workers, shuffle=True)
        # Each clip contributes `feat_crop_len` feature vectors, hence the
        # k-means mini-batch size of `batch_size * feat_crop_len`.
        clusterer = MiniBatchKMeans(num_clusters, batch_size=batch_size*feat_crop_len, random_state=0)
        # Stream the dataset through `partial_fit` to avoid holding all features in memory at once.
        for item in tqdm(feat_loading_dset):
            batch = item['feature'].reshape(-1, feat_depth).float().numpy()
            clusterer.partial_fit(batch)
        joblib.dump(clusterer, cached_path)
        print(f'Saved the calculated Clusterer @ {cached_path}')
        return clusterer


if __name__ == '__main__':
    from omegaconf import OmegaConf

    config = OmegaConf.load('./configs/vggsound_featcluster_transformer.yaml')
    config.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_specs_vqgan/checkpoints/epoch_39.ckpt'
    model = instantiate_from_config(config.model.params.cond_stage_config)
    print(model)
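
    # Illustrative sanity check (a minimal sketch; the random features below are
    # assumptions, not dataset inputs): assign a few vectors to their nearest
    # cluster centres, which is what `encode` does for every time step of a clip.
    K, D = model.clusterer.cluster_centers_.shape
    feats = np.random.rand(16, D).astype(np.float32)  # 16 random D-dim feature vectors
    ind = model.clusterer.predict(feats)              # (16,) cluster indices in [0, K)
    quant = model.clusterer.cluster_centers_[ind]     # (16, D) quantised features
    print(f'K={K}, indices: {ind.shape}, quantised: {quant.shape}')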