import datetime import json import os import pickle from functools import partial, wraps from pickle import dump, load from typing import Literal import click import numpy as np import PIL.Image import torch from sklearn.mixture import GaussianMixture from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from torch.utils.data import Subset from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import dnnlib from dataset import ImageFolderDataset from flowutils import PatchFlow, sanitize_locals from networks_edm2 import Precond DEVICE: Literal["cuda", "cpu"] = "cpu" model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions" config_presets = { "edm2-img64-s-fid": f"{model_root}/edm2-img64-s-1073741-0.075.pkl", # fid = 1.58 "edm2-img64-m-fid": f"{model_root}/edm2-img64-m-2147483-0.060.pkl", # fid = 1.43 "edm2-img64-l-fid": f"{model_root}/edm2-img64-l-1073741-0.040.pkl", # fid = 1.33 } class StandardRGBEncoder: def __init__(self): super().__init__() def encode(self, x): # raw pixels => final pixels return x.to(torch.float32) / 127.5 - 1 def decode(self, x): # final latents => raw pixels return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8) class EDMScorer(torch.nn.Module): def __init__( self, net, stop_ratio=0.8, # Maximum ratio of noise levels to compute num_steps=10, # Number of noise levels to evaluate. use_fp16=False, # Execute the underlying model at FP16 precision? sigma_min=0.002, # Minimum supported noise level. sigma_max=80, # Maximum supported noise level. sigma_data=0.5, # Expected standard deviation of the training data. rho=7, # Time step discretization. ): super().__init__() self.config = sanitize_locals(locals(), ignore_keys="net") self.config["EDMNet"] = dict(net.init_kwargs) self.use_fp16 = use_fp16 self.sigma_min = sigma_min self.sigma_max = sigma_max self.sigma_data = sigma_data self.net = net.eval() self.encoder = StandardRGBEncoder() # Adjust noise levels based on how far we want to accumulate self.sigma_min = 1e-1 self.sigma_max = sigma_max * stop_ratio step_indices = torch.arange(num_steps, dtype=torch.float64) t_steps = ( self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho)) ) ** rho self.register_buffer("sigma_steps", t_steps.to(torch.float64)) @torch.no_grad def forward( self, x, force_fp32=False, ): x = self.encoder.encode(x).to(torch.float32) batch_scores = [] for sigma in self.sigma_steps: xhat = self.net(x, sigma, force_fp32=force_fp32) c_skip = self.net.sigma_data**2 / (sigma**2 + self.net.sigma_data**2) score = xhat - (c_skip * x) batch_scores.append(score) batch_scores = torch.stack(batch_scores, axis=1) return batch_scores class ScoreFlow(torch.nn.Module): def __init__(self, scorenet, device="cpu", **flow_kwargs): super().__init__() h = w = scorenet.net.img_resolution c = scorenet.net.img_channels num_sigmas = len(scorenet.sigma_steps) self.flow = PatchFlow((num_sigmas, c, h, w), **flow_kwargs) self.flow = self.flow.to(device) self.scorenet = scorenet.to(device).eval().requires_grad_(False) self.flow.init_weights() self.config = dict() self.config.update(**self.scorenet.config) self.config.update(self.flow.config) def forward(self, x, **score_kwargs): x_scores = self.scorenet(x, **score_kwargs) return self.flow(x_scores) def build_model_from_config(model_params): net = Precond(**model_params["EDMNet"]) scorenet = EDMScorer(net=net, **model_params["EDMScorer"]) scoreflow = ScoreFlow(scorenet=scorenet, **model_params["PatchFlow"]) print("Built model from config") return scoreflow def build_model_from_pickle(preset="edm2-img64-s-fid", device="cpu"): netpath = config_presets[preset] with dnnlib.util.open_url(netpath, verbose=1) as f: data = pickle.load(f) net = data["ema"] model = EDMScorer(net, num_steps=20).to(device) return model def quantile_scorer(gmm, X, y=None): return np.quantile(gmm.score_samples(X), 0.1) def compute_gmm_likelihood(x_score, gmmdir): with open(f"{gmmdir}/gmm.pkl", "rb") as f: clf = load(f) nll = -clf.score_samples(x_score) with np.load(f"{gmmdir}/refscores.npz", "rb") as f: ref_nll = f["arr_0"] percentile = (ref_nll < nll).mean() return nll, percentile @torch.inference_mode def test_runner(device="cpu"): # f = "doge.jpg" f = "goldfish.JPEG" image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS) image = np.array(image) image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1) x = torch.from_numpy(image).unsqueeze(0).to(device) model = build_model_from_pickle(device=device) scores = model(x) return scores def test_flow_runner(preset, device="cpu", load_weights=None): # f = "doge.jpg" f = "goldfish.JPEG" image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS) image = np.array(image) image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1) x = torch.from_numpy(image).unsqueeze(0).to(device) scorenet = build_model_from_pickle(preset) score_flow = ScoreFlow(scorenet, device=device) if load_weights is not None: score_flow.flow.load_state_dict(torch.load(load_weights)) heatmap = score_flow(x) print(heatmap.shape) heatmap = score_flow(x).detach().cpu().numpy() heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255 im = PIL.Image.fromarray(heatmap[0, 0]) im.convert("RGB").save( "heatmap.png", ) return @click.group() def cmdline(): global DEVICE DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def common_args(func): @wraps(func) @click.option( "--preset", help="Configuration preset", metavar="STR", type=str, default="edm2-img64-s-fid", show_default=True, ) @click.option( "--dataset_path", help="Path to the dataset", metavar="ZIP|DIR", type=str, default=None, ) @click.option( "--outdir", help="Where to load/save the results", metavar="DIR", type=str, required=True, ) def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper @cmdline.command("train-gmm") @click.option( "--gridsearch", help="Whether to use a grid search on a number of components to find the best fit", is_flag=True, default=False, ) @common_args def train_gmm(preset, outdir, gridsearch=False, **kwargs): outdir = f"{outdir}/{preset}" score_path = f"{outdir}/imagenette_score_norms.pt" X = torch.load(score_path).numpy() print(f"Loaded score norms from: {score_path} - # Samples: {X.shape[0]}") gm = GaussianMixture( n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000 ) clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)]) if gridsearch: param_grid = dict( GMM__n_components=range(2, 11, 1), ) grid = GridSearchCV( estimator=clf, param_grid=param_grid, cv=5, n_jobs=2, verbose=1, scoring=quantile_scorer, ) grid_result = grid.fit(X) print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_)) print("-----" * 15) means = grid_result.cv_results_["mean_test_score"] stds = grid_result.cv_results_["std_test_score"] params = grid_result.cv_results_["params"] for mean, stdev, param in zip(means, stds, params): print("%f (%f) with: %r" % (mean, stdev, param)) clf = grid.best_estimator_ clf.fit(X) inlier_nll = -clf.score_samples(X) print("Saving reference inlier scores ... ") os.makedirs(outdir, exist_ok=True) with open(f"{outdir}/refscores.npz", "wb") as f: np.savez_compressed(f, inlier_nll) with open(f"{outdir}/gmm.pkl", "wb") as f: dump(clf, f, protocol=5) print("Saved GMM pickle.") @cmdline.command(name="cache-scores") @click.option( "--batch_size", help="Number of samples per batch", metavar="INT", type=int, default=64, show_default=True, ) @common_args def cache_score_norms(preset, dataset_path, outdir, batch_size): device = DEVICE dsobj = ImageFolderDataset(path=dataset_path, resolution=64) refimg, reflabel = dsobj[0] print(f"Loading dataset from {dataset_path}") print( f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}" ) dsloader = torch.utils.data.DataLoader( dsobj, batch_size=batch_size, num_workers=4, prefetch_factor=2 ) model = build_model_from_pickle(preset=preset, device=device) score_norms = [] for x, _ in tqdm(dsloader): s = model(x.to(device)) s = s.square().sum(dim=(2, 3, 4)) ** 0.5 score_norms.append(s.cpu()) score_norms = torch.cat(score_norms, dim=0) os.makedirs(f"{outdir}/{preset}/", exist_ok=True) with open(f"{outdir}/{preset}/imagenette_score_norms.pt", "wb") as f: torch.save(score_norms, f) print(f"Computed score norms for {score_norms.shape[0]} samples") @cmdline.command(name="train-flow") @click.option( "--epochs", help="Number of epochs", metavar="INT", type=int, default=10, show_default=True, ) @click.option( "--num_flows", help="Number of normalizing flow functions in the PatchFlow model", metavar="INT", type=int, default=4, show_default=True, ) @click.option( "--batch_size", help="Number of samples per batch", metavar="INT", type=int, default=128, show_default=True, ) @common_args def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs): print("using device:", DEVICE) device = DEVICE dsobj = ImageFolderDataset(path=dataset_path, resolution=64) print(f"Loaded {len(dsobj)} samples from {dataset_path}") # Subset of training dataset val_ratio = 0.1 train_len = int((1 - val_ratio) * len(dsobj)) val_len = len(dsobj) - train_len print( f"Generating train/test split with ratio={val_ratio} -> {train_len}/{val_len}..." ) train_ds = Subset(dsobj, range(train_len)) val_ds = Subset(dsobj, range(train_len, train_len + val_len)) trainiter = torch.utils.data.DataLoader( train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True ) testiter = torch.utils.data.DataLoader( val_ds, batch_size=batch_size * 2, num_workers=4, prefetch_factor=2 ) scorenet = build_model_from_pickle(preset) model = ScoreFlow(scorenet, device=device, **flow_kwargs) opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5) train_step = partial( PatchFlow.stochastic_step, flow_model=model.flow, opt=opt, train=True, n_patches=128, device=device, ) eval_step = partial( PatchFlow.stochastic_step, flow_model=model.flow, train=False, n_patches=256, device=device, ) experiment_dir = f"{outdir}/{preset}" os.makedirs(experiment_dir, exist_ok=True) timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M") writer = SummaryWriter(f"{experiment_dir}/logs/{timestamp}") with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f: json.dump(model.config, f, sort_keys=True, indent=4) with open(f"{experiment_dir}/config.json", "w") as f: json.dump(model.config, f, sort_keys=True, indent=4) # totaliters = int(epochs * train_len) pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?") step = 0 for e in pbar: for x, _ in trainiter: x = x.to(device) scores = model.scorenet(x) if step == 0: with torch.inference_mode(): val_loss = eval_step(scores, x) # Log details about model writer.add_graph( model.flow.flows, ( torch.zeros(1, scores.shape[1], device=device), torch.zeros( 1, model.flow.position_encoding.cached_penc.shape[-1], device=device, ), ), ) train_loss = train_step(scores, x) if (step + 1) % 10 == 0: prev_val_loss = val_loss val_loss = 0.0 with torch.inference_mode(): for i, (x, _) in enumerate(testiter): x = x.to(device) scores = model.scorenet(x) val_loss += eval_step(scores, x) break val_loss /= i + 1 writer.add_scalar("loss/val", train_loss, step) if val_loss < prev_val_loss: torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt") writer.add_scalar("loss/train", train_loss, step) pbar.set_description( f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}" ) step += 1 # Squeeze the juice best_ckpt = torch.load(f"{experiment_dir}/flow.pt") model.flow.load_state_dict(best_ckpt) pbar = tqdm(range(10), desc="(Tuning) Step:? - Loss: ?") for e in pbar: for x, _ in testiter: x = x.to(device) scores = model.scorenet(x) train_loss = train_step(scores, x) writer.add_scalar("loss/train", train_loss, step) pbar.set_description(f"(Tuning) Step: {step:d} - Loss: {train_loss:.3f}") step += 1 # Save final model torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt") writer.close() # cache_score_norms( # preset=preset, # dataset_path="/GROND_STOR/amahmood/datasets/img64/", # device="cuda", # ) # train_gmm( # f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}" # ) # s = test_runner(device=device) # s = s.square().sum(dim=(2, 3, 4)) ** 0.5 # s = s.to("cpu").numpy() # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/") # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile") if __name__ == "__main__": cmdline()