# StyleTransfer/model.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


class NeuralStyleTransfer:
    def __init__(
        self,
        style_image,
        content_image,
        extractor,
        n_style_layers=5,
        n_content_layers=5,
    ):
        # load the pretrained feature extractor (ImageNet weights, no classifier head)
        if extractor == "inception_v3":
            self.feature_extractor = tf.keras.applications.InceptionV3(
                include_top=False, weights="imagenet"
            )
        elif extractor == "vgg19":
            self.feature_extractor = tf.keras.applications.VGG19(
                include_top=False, weights="imagenet"
            )
        elif extractor == "resnet50":
            self.feature_extractor = tf.keras.applications.ResNet50(
                include_top=False, weights="imagenet"
            )
        elif extractor == "mobilenet_v2":
            self.feature_extractor = tf.keras.applications.MobileNetV2(
                include_top=False, weights="imagenet"
            )
        elif isinstance(extractor, tf.keras.Model):
            self.feature_extractor = extractor
        else:
            raise ValueError(f"Feature extractor {extractor!r} not found")

        # freeze the extractor; only the generated image is optimized
        self.feature_extractor.trainable = False

        # number of conv layers used for style and content features
        self.n_style_layers = n_style_layers
        self.n_content_layers = n_content_layers

        self.style_image = self._load_img(style_image)
        self.content_image = self._load_img(content_image)

    def tensor_to_image(self, tensor):
        """Converts a tensor to a PIL image, dropping a leading batch dimension if present."""
        if tensor.shape.rank > 3:
            assert tensor.shape[0] == 1, "only a batch size of 1 is supported"
            tensor = tensor[0]
        return tf.keras.preprocessing.image.array_to_img(tensor)

    def _load_img(self, image):
        """Loads an image from disk and rescales its longer side to max_dim pixels."""
        max_dim = 512
        image = tf.io.read_file(image)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.image.convert_image_dtype(image, tf.float32)

        shape = tf.cast(tf.shape(image)[:-1], tf.float32)
        long_dim = tf.reduce_max(shape)
        scale = max_dim / long_dim

        new_shape = tf.cast(shape * scale, tf.int32)
        image = tf.image.resize(image, new_shape)
        image = image[tf.newaxis, :]

        # store as uint8 in [0, 255]; _preprocess_image rescales before the network
        image = tf.image.convert_image_dtype(image, tf.uint8)
        return image
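
    # Rescaling arithmetic, with hypothetical numbers: a 768x1024 (height x width)
    # photo has long_dim = 1024, so scale = 512 / 1024 = 0.5 and the image is
    # resized to 384x512 before the batch dimension is added.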

    def imshow(self, image, title=None):
        """Displays an image with a corresponding title."""
        if len(image.shape) > 3:
            image = tf.squeeze(image, axis=0)

        plt.imshow(image)
        if title:
            plt.title(title)

    def show_images_with_objects(self, images, titles=None):
        """Displays a row of images with corresponding titles."""
        titles = titles or []
        if len(images) != len(titles):
            return

        plt.figure(figsize=(20, 12))
        for idx, (image, title) in enumerate(zip(images, titles)):
            plt.subplot(1, len(images), idx + 1)
            plt.xticks([])
            plt.yticks([])
            self.imshow(image, title)

    def _preprocess_image(self, image):
        """Scales a uint8 image in [0, 255] to [-1, 1] (the Inception/MobileNet
        convention, applied here uniformly to every extractor)."""
        image = tf.cast(image, dtype=tf.float32)
        image = (image / 127.5) - 1.0
        return image

    def get_output_layers(self):
        # collect the names of all conv layers in the extractor
        all_layers = [
            layer.name
            for layer in self.feature_extractor.layers
            if "conv" in layer.name
        ]

        # style layers: the first n_style_layers conv layers (shallow features)
        style_layers = all_layers[: self.n_style_layers]

        # content layers: n_content_layers conv layers, walking backwards from
        # the second-to-last conv layer (deep features)
        content_layers = all_layers[-2 : -self.n_content_layers - 2 : -1]

        content_and_style_layers = content_layers + style_layers
        return content_and_style_layers
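
    # For illustration (assuming the "vgg19" extractor and the default
    # n_style_layers=5, n_content_layers=5), this returns, in order:
    #   content: block5_conv3, block5_conv2, block5_conv1, block4_conv4, block4_conv3
    #   style:   block1_conv1, block1_conv2, block2_conv1, block2_conv2, block3_conv1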

    def build(self, layers_name):
        """Rewires the extractor so it outputs the requested layers' activations."""
        output_layers = [
            self.feature_extractor.get_layer(name).output for name in layers_name
        ]
        model = tf.keras.Model(self.feature_extractor.input, output_layers)
        self.feature_extractor = model

    def _loss(self, target_img, features_img, type):
        """
        Mean squared error between target features and generated-image features.

        target_img:
            the target image's features (style or content)
        features_img:
            the generated image's features (style or content)
        """
        loss = tf.reduce_mean(tf.square(features_img - target_img))

        # the content loss is conventionally scaled by 1/2
        if type == "content":
            return 0.5 * loss
        return loss

    def _gram_matrix(self, input_tensor):
        """
        Calculates the Gram matrix and divides by the number of spatial locations.

        input_tensor:
            the output of a conv layer on the style image,
            shape = (batch_size, height, width, channels)
        """
        result = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
        return result / num_locations
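
    # Shape check, with hypothetical numbers: for an activation of shape
    # (1, 32, 32, 64), the einsum contracts the 32x32 spatial grid into a
    # (1, 64, 64) Gram matrix, which is then divided by 32 * 32 = 1024 locations.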

    def get_features(self, image, type):
        """Runs the extractor and returns style (Gram matrix) or content features."""
        preprocess_image = self._preprocess_image(image)
        outputs = self.feature_extractor(preprocess_image)

        # the extractor emits content layers first, then style layers
        # (see get_output_layers)
        if type == "style":
            outputs = outputs[self.n_content_layers :]
            features = [self._gram_matrix(style_output) for style_output in outputs]
        elif type == "content":
            features = outputs[: self.n_content_layers]
        else:
            raise ValueError("type must be 'style' or 'content'")
        return features

    def _style_content_loss(
        self,
        style_targets,
        style_outputs,
        content_targets,
        content_outputs,
        style_weight,
        content_weight,
    ):
        """
        Calculates the total loss of the style transfer.

        style_targets:
            the style features of the style image
        style_outputs:
            the style features of the generated image
        content_targets:
            the content features of the content image
        content_outputs:
            the content features of the generated image
        style_weight:
            the weight of the style loss
        content_weight:
            the weight of the content loss
        """
        # sum the per-layer losses, then apply the global weights
        style_loss = style_weight * tf.add_n(
            [
                self._loss(style_target, style_output, type="style")
                for style_target, style_output in zip(style_targets, style_outputs)
            ]
        )

        content_loss = content_weight * tf.add_n(
            [
                self._loss(content_target, content_output, type="content")
                for content_target, content_output in zip(
                    content_targets, content_outputs
                )
            ]
        )

        total_loss = style_loss + content_loss
        return total_loss, style_loss, content_loss
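
    # The raw style and content losses differ by orders of magnitude, so the two
    # weights are usually very asymmetric. As a hypothetical starting point (not
    # fixed anywhere in this file), style_weight=1e-2 with content_weight=1e4 is
    # a common pairing in style-transfer examples.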

    def _grad_loss(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        var_weight,
    ):
        """
        Calculates the gradients of the loss function with respect to the
        generated image.

        generated_image:
            the generated image (a tf.Variable, so the tape can track it)
        """
        with tf.GradientTape() as tape:
            style_features = self.get_features(generated_image, type="style")
            content_features = self.get_features(generated_image, type="content")

            loss, style_loss, content_loss = self._style_content_loss(
                style_target,
                style_features,
                content_target,
                content_features,
                style_weight,
                content_weight,
            )

            # total variation regularization penalizes high-frequency noise
            variational_loss = var_weight * tf.image.total_variation(generated_image)
            loss += variational_loss

        grads = tape.gradient(loss, generated_image)
        return grads, loss, [style_loss, content_loss, variational_loss]

    def _update_image_with_style(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        optimizer,
        var_weight,
    ):
        grads, loss, loss_list = self._grad_loss(
            generated_image,
            style_target,
            content_target,
            style_weight,
            content_weight,
            var_weight,
        )

        optimizer.apply_gradients([(grads, generated_image)])

        # keep pixel values in the valid [0, 255] range after each step
        generated_image.assign(
            tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=255.0)
        )
        return loss_list
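

# A minimal usage sketch. Everything below is illustrative, not part of the class:
# the file names, step count, and weights are hypothetical, and only methods
# defined above are used.
if __name__ == "__main__":
    nst = NeuralStyleTransfer(
        style_image="style.jpg",
        content_image="content.jpg",
        extractor="vgg19",
    )

    # rewire the extractor to expose the content + style layers
    nst.build(nst.get_output_layers())

    # fixed targets computed once from the two input images
    style_target = nst.get_features(nst.style_image, type="style")
    content_target = nst.get_features(nst.content_image, type="content")

    # optimize the generated image directly, starting from the content image
    generated_image = tf.Variable(tf.cast(nst.content_image, tf.float32))
    optimizer = tf.keras.optimizers.Adam(learning_rate=5.0)

    for step in range(100):
        losses = nst._update_image_with_style(
            generated_image,
            style_target,
            content_target,
            style_weight=1e-2,
            content_weight=1e4,
            optimizer=optimizer,
            var_weight=30.0,
        )
        if step % 10 == 0:
            print(f"step {step}: style/content/variation losses = {losses}")

    nst.tensor_to_image(generated_image).save("stylized.jpg")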