# NOTE: the lines originally here were residue from the web code viewer this
# listing appears to have been copied from, not part of the program:
#   Spaces: Build error | File size: 3,148 Bytes | 8f87579 | line-number gutter 1-108
import argparse
import pickle
import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm
from model import Generator
from calc_inception import load_patched_inception_v3
@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device,
    latent_dim=512,
):
    """Generate ``n_sample`` images and return their Inception features.

    Latents are drawn from a standard normal distribution in batches of at
    most ``batch_size``, passed through ``generator``, and the resulting
    images are passed through ``inception``. Gradients are disabled for the
    whole call.

    Args:
        generator: Callable invoked as
            ``generator([latent], truncation=..., truncation_latent=...)``;
            it must return a pair whose first element is the image batch.
        inception: Callable applied to the image batch; element ``[0]`` of
            its result is flattened to one row per image.
        truncation: Forwarded to ``generator`` as ``truncation``.
        truncation_latent: Forwarded to ``generator`` as ``truncation_latent``.
        batch_size: Maximum number of latents per forward pass (positive).
        n_sample: Total number of samples to generate (positive).
        device: Device on which the random latents are created.
        latent_dim: Width of each latent vector (default 512).

    Returns:
        A CPU tensor with ``n_sample`` rows, one flattened feature vector per
        generated image.

    Raises:
        ValueError: If ``batch_size`` or ``n_sample`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    if n_sample < 1:
        raise ValueError(f'n_sample must be positive, got {n_sample}')

    n_batch, resid = divmod(n_sample, batch_size)
    batch_sizes = [batch_size] * n_batch
    if resid:
        # Only add the remainder batch when it is non-empty; an empty batch
        # would otherwise be pushed through both networks.
        batch_sizes.append(resid)

    features = []
    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, latent_dim, device=device)
        # Use the generator that was passed in, not a module-level global.
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        # Move each batch to the CPU right away so features do not pile up
        # on the compute device.
        features.append(feat.to('cpu'))

    return torch.cat(features, 0)
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Return the Frechet distance between two Gaussians.

    Computes ``|m_s - m_r|^2 + Tr(C_s) + Tr(C_r) - 2 * Tr(sqrt(C_s @ C_r))``
    from the supplied means and covariance matrices.

    Args:
        sample_mean: Mean vector of the generated-sample features.
        sample_cov: Covariance matrix of the generated-sample features.
        real_mean: Mean vector of the reference features.
        real_cov: Covariance matrix of the reference features.
        eps: Value added to the diagonal of both covariances when the first
            matrix square root contains non-finite entries.

    Returns:
        The scalar distance.

    Raises:
        ValueError: If the matrix square root has a diagonal imaginary part
            larger than ``1e-3`` in absolute value.
    """
    # Call sqrtm without the deprecated ``disp`` keyword: in that form it
    # returns just the matrix on every SciPy version, so nothing here depends
    # on the (matrix, error-estimate) tuple that ``disp=False`` produced.
    cov_sqrt = linalg.sqrtm(sample_cov @ real_cov)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        # Retry with a small ridge on both covariances.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # A small imaginary part is treated as numerical noise and dropped;
        # a large one on the diagonal is reported as an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
    return mean_norm + trace
if __name__ == '__main__':
    # Script entry point: load a generator checkpoint, extract Inception
    # features from generated samples, and print the FID against
    # precomputed reference statistics.
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--truncation', type=float, default=1,
        help='truncation value passed to the generator; values below 1 '
             'make the script compute a mean latent',
    )
    parser.add_argument(
        '--truncation_mean', type=int, default=4096,
        help='argument passed to mean_latent() when truncation is below 1',
    )
    parser.add_argument(
        '--batch', type=int, default=64,
        help='batch size used while generating samples',
    )
    parser.add_argument(
        '--n_sample', type=int, default=50000,
        help='number of samples to generate',
    )
    parser.add_argument(
        '--size', type=int, default=256,
        help='size argument passed to the Generator constructor',
    )
    parser.add_argument(
        '--inception', type=str, required=True,
        help="pickle file holding a dict with 'mean' and 'cov' entries",
    )
    parser.add_argument('ckpt', metavar='CHECKPOINT')
    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g = nn.DataParallel(g)
    g.eval()

    if args.truncation < 1:
        with torch.no_grad():
            # nn.DataParallel does not expose the wrapped model's own
            # methods, so mean_latent must be reached through ``.module``.
            mean_latent = g.module.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f'extracted {features.shape[0]} features')

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file; only point --inception at statistics files from a trusted source.
    with open(args.inception, 'rb') as f:
        embeds = pickle.load(f)
    real_mean = embeds['mean']
    real_cov = embeds['cov']

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
    print('fid:', fid)
# (end of file)