# pip install fastai==2.4
import cv2 as cv
import numpy as np
import streamlit as st
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.resnet import resnet18
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from PIL import Image
from skimage.color import rgb2lab, lab2rgb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SIZE = 256
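
# The dataset below works in L*a*b color space: the grayscale L channel is the
# model input and the two ab (color) channels are the prediction target.
# L nominally spans [0, 100] and ab roughly [-110, 110], so dividing by 50/110
# (and shifting L) maps all channels into [-1, 1] for training.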
class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # convert RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.  # between -1 and 1
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader
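
# Hypothetical usage (train_paths/val_paths would be lists of image file
# paths; they are not defined in this app):
#   train_dl = make_dataloaders(paths=train_paths, split='train')
#   val_dl = make_dataloaders(paths=val_paths, split='val')

# UnetBlock builds one down/up level of a U-Net. Blocks nest recursively via
# `submodule`; in forward(), every non-outermost block concatenates its input
# with its output along the channel axis, which is the U-Net skip connection
# (and why the inner up-convolutions take ni * 2 input channels).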
class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None:
            input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout:
                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
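
# Unet assembles the blocks from the innermost level outward: the deepest
# block first, then (n_down - 5) middle blocks with dropout, then three blocks
# that halve the filter count, and finally the outermost block mapping the
# 1-channel L input to the 2-channel ab output.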
class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(
            num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(
                num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(
                out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(
            output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)
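
# PatchGAN-style discriminator: instead of a single real/fake score per image,
# the final 1-channel convolution outputs a grid of logits, each judging one
# patch of the input. It sees the L channel concatenated with the (real or
# generated) ab channels, hence input_c=3.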
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        # Use stride 1 (instead of 2) for the last block of this loop.
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1),
                                  s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]
        # No normalization or activation on the last layer: it outputs raw logits.
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            layers += [nn.BatchNorm2d(nf)]
        if act:
            layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
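
# GANLoss wraps the adversarial objective: 'vanilla' uses BCE on logits,
# 'lsgan' uses MSE. The real/fake label values are stored as buffers and
# expanded to match the discriminator's patch-grid output shape.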
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
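
# Weight initialization for conv and batch-norm layers; 'norm' draws conv
# weights from N(0, 0.02), following common DCGAN/pix2pix practice.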
def init_weights(net, init='norm', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)
    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model
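
# MainModel ties everything together as a pix2pix-style conditional GAN:
# the generator predicts ab from L, the discriminator judges (L, ab) pairs,
# and the generator loss is the adversarial term plus lambda_L1 * L1 between
# predicted and real ab channels.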
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        if net_G is None:
            self.net_G = init_model(
                Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(
            input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(),
                                lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(),
                                lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        # Detach the generator output so D's loss doesn't backprop into G.
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(
            self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        # One training step: update D first (with G's output detached),
        # then freeze D and update G.
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
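
# Hypothetical training loop (not executed by this app; `train_dl` and
# `epochs` are assumptions):
#   model = MainModel()
#   for e in range(epochs):
#       loss_meter_dict = create_loss_meters()
#       for data in train_dl:
#           model.setup_input(data)
#           model.optimize()
#           update_losses(model, loss_meter_dict, count=data['L'].size(0))
#       log_results(loss_meter_dict)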

class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)
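
# lab_to_rgb inverts the dataset normalization (L: [-1, 1] -> [0, 100],
# ab: [-1, 1] -> [-110, 110]) before handing the images to skimage's lab2rgb,
# which returns float RGB arrays in [0, 1].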
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L and ab tensors back to RGB images."""
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

def visualize(model, data, dims):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    for i in range(1):  # this app only ever shows the first image of the batch
        # Resize the generated image back to the uploaded image's dimensions;
        # cv.resize expects dsize as (width, height).
        img = cv.resize(fake_imgs[i], dsize=(
            dims[1], dims[0]), interpolation=cv.INTER_CUBIC)
        st.image(img, caption="Output image",
                 use_column_width='auto', clamp=True)

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")
def build_res_unet(n_input=1, n_output=2, size=256):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fastai 2.4's create_body takes the architecture callable plus a
    # pretrained flag, rather than an already-instantiated model.
    body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G

net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
model.load_state_dict(torch.load("main-model.pt", map_location=device))

class MyDataset(torch.utils.data.Dataset):
    """Dataset over in-memory PIL images (e.g. a Streamlit upload)."""

    def __init__(self, img_list):
        super().__init__()
        self.img_list = img_list
        self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img = self.img_list[idx]
        img = self.augmentations(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # convert RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.  # between -1 and 1
        return {'L': L, 'ab': ab}

def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    dataset = MyDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader
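
# ---- Streamlit UI ----
# The uploaded image is wrapped in a single-item dataset, run through the
# generator, and the colorized result is resized back to the original size.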
st.write("""
# Image Recolorisation
""")
st.subheader("Created by Pushkar")
file_up = st.file_uploader("Upload a JPG image", type=["jpg", "jpeg", "png"])
if file_up is not None:
    im = Image.open(file_up).convert("RGB")  # rgb2lab expects 3 channels
    dims = (im.height, im.width)
    st.text(f"Size of uploaded image {dims}")
    st.image(im, caption="Uploaded Image.", use_column_width='auto')
    test_dl = make_dataloaders2(img_list=[im])
    for data in test_dl:
        # Inference only: visualize() runs setup_input + forward under
        # no_grad, so there is no need to call model.optimize() here.
        visualize(model, data, dims)