# Source: dinov2-pca / app.py (RRoundTable), commit 26be1cc "Add app for dinov2 pca", 2.4 kB.
# ("raw" and "history blame" were navigation links of the file viewer this was copied from.)
import torch
import torch.nn as nn
import cv2
import gradio as gr
import glob
from typing import List
import torch.nn.functional as F
import torchvision.transforms as T
from sklearn.decomposition import PCA
import sklearn
import numpy as np
# Constants
# Number of patches per image side. The model input is resized to
# (patch_h * 14) x (patch_w * 14) pixels; the factor 14 presumably matches
# the patch size of the "vitl14" backbone loaded below.
patch_h = 40
patch_w = 40

# Use GPU if available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DINOv2 backbone, fetched through torch.hub.
# NOTE(review): the model is not moved to `device` here; doing so also
# requires the inputs built in query_image to live on the same device.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# The model is only used for inference, so switch off training-mode behaviour.
model.eval()

# Transforms
transform = T.Compose([
    T.Resize((patch_h * 14, patch_w * 14)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Empty tensor holding a batch of four model inputs
imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14)

# PCA
pca = PCA(n_components=3)
def _to_model_input(img) -> torch.Tensor:
    """Convert one H x W x C image array into a float32 C x H x W tensor."""
    arr = np.asarray(img)
    chw = np.ascontiguousarray(np.transpose(arr, (2, 0, 1)), dtype=np.float32)
    tensor = torch.from_numpy(chw)
    if arr.dtype == np.uint8:
        # The Normalize step of `transform` uses ImageNet statistics, which
        # are defined for pixel values in [0, 1], not [0, 255].
        tensor = tensor / 255.0
    return tensor


def query_image(img1, img2, img3, img4) -> List[np.ndarray]:
    """Visualise DINOv2 patch features of four images with a shared PCA.

    The four images are resized and normalised with the module-level
    ``transform``, passed through ``model`` as one batch, and the resulting
    patch features are projected twice with a 3-component PCA: the first
    projection separates background (first component < 0) from foreground,
    the second one is fitted on foreground patches only and min-max scaled
    per component so it can be shown as RGB.

    Args:
        img1, img2, img3, img4: Images as H x W x 3 arrays (assumed to be
            what ``gr.Image()`` delivers). ``uint8`` inputs are scaled to
            [0, 1] before normalisation.

    Returns:
        One array per input image, each of shape ``(patch_h, patch_w, 3)``
        with values in [0, 1]; background patches are all zeros.

    Raises:
        ValueError: If any of the four images is missing (``None``).
    """
    imgs = [img1, img2, img3, img4]
    if any(img is None for img in imgs):
        raise ValueError("All four input images must be provided.")

    n_imgs = len(imgs)
    n_patches = patch_h * patch_w

    # Build a fresh batch per call instead of writing into a shared
    # module-level tensor, so overlapping requests cannot clobber each other.
    batch = torch.stack([transform(_to_model_input(img)) for img in imgs])

    # Get feature from patches. The batch is sent to whatever device the
    # model currently lives on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        features_dict = model.forward_features(batch.to(model_device))
    # Skip the first token of every sequence (presumably the class token);
    # the remaining tokens are reshaped to one row per patch.
    features = features_dict['x_prenorm'][:, 1:]
    features = features.reshape(n_imgs * n_patches, -1).cpu().numpy()

    # PCA over all patches. Local estimators keep calls independent.
    all_pca = PCA(n_components=3)
    all_pca.fit(features)
    projected = all_pca.transform(features)

    # Foreground/Background: thresholded on the raw (unscaled) first component.
    background = projected[:, 0] < 0
    foreground = ~background

    rgb = np.zeros((n_imgs * n_patches, 3))
    # A 3-component PCA needs at least three samples; otherwise everything
    # is left as background (zeros).
    if np.count_nonzero(foreground) >= 3:
        fg_features = features[foreground]

        # PCA with only foreground
        fg_pca = PCA(n_components=3)
        fg_pca.fit(fg_features)
        fg_projected = fg_pca.transform(fg_features)

        # Min-max normalisation per component; a constant component maps to
        # 0 instead of producing NaN from 0 / 0.
        lowest = fg_projected.min(axis=0)
        span = fg_projected.max(axis=0) - lowest
        span[span == 0] = 1.0
        rgb[foreground] = (fg_projected - lowest) / span

    rgb = rgb.reshape(n_imgs, patch_h, patch_w, 3)
    return list(rgb)
# Text shown below the title on the Gradio page.
description = """
DINOV2 PCA
"""

# Four image inputs are mapped by query_image onto four image outputs.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
    outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
    title="DINOV2 PCA",
    description=description,
    examples=[],
)

# Only start the web server when executed as a script, so that importing
# this module (e.g. for reuse of `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()