import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


class NeuralStyleTransfer:
    """Neural style transfer driven by a frozen, pretrained CNN feature extractor."""

    def __init__(self, style_image, content_image, extractor,
                 n_style_layers=5, n_content_layers=5):
        # load the pretrained model (headless, with ImageNet weights)
        if extractor == "inception_v3":
            self.feature_extractor = tf.keras.applications.InceptionV3(
                include_top=False, weights="imagenet"
            )
        elif extractor == "vgg19":
            self.feature_extractor = tf.keras.applications.VGG19(
                include_top=False, weights="imagenet"
            )
        elif extractor == "resnet50":
            self.feature_extractor = tf.keras.applications.ResNet50(
                include_top=False, weights="imagenet"
            )
        elif extractor == "mobilenet_v2":
            self.feature_extractor = tf.keras.applications.MobileNetV2(
                include_top=False, weights="imagenet"
            )
        elif isinstance(extractor, tf.keras.Model):
            self.feature_extractor = extractor
        else:
            raise ValueError(f"Unknown feature extractor: {extractor!r}")

        # freeze the model; only the generated image is optimized
        self.feature_extractor.trainable = False

        # define how many conv layers feed the style and content features
        self.n_style_layers = n_style_layers
        self.n_content_layers = n_content_layers

        self.style_image = self._load_img(style_image)
        self.content_image = self._load_img(content_image)

    def tensor_to_image(self, tensor):
        """Converts a tensor to a PIL image, dropping a leading batch axis."""
        if len(tensor.shape) > 3:
            assert tensor.shape[0] == 1
            tensor = tensor[0]
        return tf.keras.preprocessing.image.array_to_img(tensor)

    def _load_img(self, image):
        """Loads an image file and scales its longest side to 512 px."""
        max_dim = 512
        image = tf.io.read_file(image)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.image.convert_image_dtype(image, tf.float32)

        # resize so the longest side equals max_dim, preserving aspect ratio
        shape = tf.cast(tf.shape(image)[:-1], tf.float32)
        long_dim = tf.reduce_max(shape)
        scale = max_dim / long_dim
        new_shape = tf.cast(shape * scale, tf.int32)
        image = tf.image.resize(image, new_shape)

        # add a batch axis and return uint8 pixels in [0, 255]
        image = image[tf.newaxis, :]
        image = tf.image.convert_image_dtype(image, tf.uint8)
        return image

    def imshow(self, image, title=None):
        """displays an image with a corresponding title"""
        if len(image.shape) > 3:
            image = tf.squeeze(image, axis=0)
        plt.imshow(image)
        if title:
            plt.title(title)

    def show_images_with_objects(self, images, titles=()):
        """displays a row of images with corresponding titles"""
        if len(images) != len(titles):
            return
        plt.figure(figsize=(20, 12))
        for idx, (image, title) in enumerate(zip(images, titles)):
            plt.subplot(1, len(images), idx + 1)
            plt.xticks([])
            plt.yticks([])
            self.imshow(image, title)

    def _preprocess_image(self, image):
        """scales pixel values from [0, 255] to [-1, 1]"""
        image = tf.cast(image, dtype=tf.float32)
        image = (image / 127.5) - 1.0
        return image
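
    # Caveat: the [-1, 1] scaling above is what InceptionV3 and MobileNetV2
    # expect; VGG19 and ResNet50 canonically use Caffe-style BGR mean
    # subtraction instead. The uniform scaling is a simplification that still
    # works in practice for style transfer.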

    def get_output_layers(self):
        # collect the names of all layers with "conv" in their name
        all_layers = [
            layer.name
            for layer in self.feature_extractor.layers
            if "conv" in layer.name
        ]
        # style layers: the first n_style_layers conv layers
        style_layers = all_layers[: self.n_style_layers]
        # content layers: n_content_layers conv layers, walking backwards
        # from the second-to-last one
        content_layers = all_layers[-2 : -self.n_content_layers - 2 : -1]
        # content layers come first; get_features relies on this ordering
        content_and_style_layers = content_layers + style_layers
        return content_and_style_layers
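
    # For example, with extractor="vgg19" and the default depths of 5, the
    # slicing above should resolve (under the stock tf.keras VGG19 layer
    # naming) to roughly:
    #   style layers:   block1_conv1, block1_conv2, block2_conv1,
    #                   block2_conv2, block3_conv1
    #   content layers: block5_conv3, block5_conv2, block5_conv1,
    #                   block4_conv4, block4_conv3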

    def build(self, layers_name):
        """Rebuilds the extractor so it returns the selected layers' outputs."""
        output_layers = [
            self.feature_extractor.get_layer(name).output for name in layers_name
        ]
        self.feature_extractor = tf.keras.Model(
            self.feature_extractor.input, output_layers
        )

    def _loss(self, target_img, features_img, loss_type):
        """
        Calculates the mean squared error between two sets of features
        target_img:
            the target image's features (style or content)
        features_img:
            the generated image's features (style or content)
        """
        loss = tf.reduce_mean(tf.square(features_img - target_img))
        if loss_type == "content":
            return 0.5 * loss
        return loss

    def _gram_matrix(self, input_tensor):
        """
        Calculates the Gram matrix and divides by the number of locations
        input_tensor:
            the output of a conv layer on the style image,
            shape = (batch_size, height, width, channels)
        """
        result = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
        return result / num_locations
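
    # In index notation, the einsum in _gram_matrix computes, per batch
    # element b,
    #     G[b, c, d] = sum_{i,j} F[b, i, j, c] * F[b, i, j, d] / (H * W),
    # i.e. spatially averaged channel-by-channel feature correlations,
    # which is what encodes "style".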

    def get_features(self, image, feature_type):
        """Extracts style (Gram matrix) or content features from an image."""
        preprocess_image = self._preprocess_image(image)
        outputs = self.feature_extractor(preprocess_image)
        # outputs are ordered [content layers..., style layers...] by build()
        if feature_type == "style":
            outputs = outputs[self.n_content_layers :]
            features = [self._gram_matrix(style_output) for style_output in outputs]
        elif feature_type == "content":
            features = outputs[: self.n_content_layers]
        return features

    def _style_content_loss(
        self,
        style_targets,
        style_outputs,
        content_targets,
        content_outputs,
        style_weight,
        content_weight,
    ):
        """
        Calculates the total loss of the style transfer
        style_targets:
            the style features of the style image
        style_outputs:
            the style features of the generated image
        content_targets:
            the content features of the content image
        content_outputs:
            the content features of the generated image
        style_weight:
            the weight of the style loss
        content_weight:
            the weight of the content loss
        """
        # sum the per-layer losses, then apply the global weights
        style_loss = style_weight * tf.add_n(
            [
                self._loss(style_target, style_output, loss_type="style")
                for style_target, style_output in zip(style_targets, style_outputs)
            ]
        )
        content_loss = content_weight * tf.add_n(
            [
                self._loss(content_target, content_output, loss_type="content")
                for content_target, content_output in zip(
                    content_targets, content_outputs
                )
            ]
        )
        total_loss = style_loss + content_loss
        return total_loss, style_loss, content_loss
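
    # Putting the pieces together, the objective minimized so far is
    #     L = style_weight   * sum_l MSE(gram_l(style),   gram_l(generated))
    #       + content_weight * sum_l 0.5 * MSE(feat_l(content), feat_l(generated)),
    # with a total-variation term added on top in _grad_loss below.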

    def _grad_loss(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        var_weight,
    ):
        """
        Calculates the gradients of the loss function with respect to the
        generated image
        generated_image:
            the generated image
        """
        with tf.GradientTape() as tape:
            style_features = self.get_features(generated_image, feature_type="style")
            content_features = self.get_features(generated_image, feature_type="content")
            loss, style_loss, content_loss = self._style_content_loss(
                style_target,
                style_features,
                content_target,
                content_features,
                style_weight,
                content_weight,
            )
            # total-variation regularization penalizes high-frequency noise
            variational_loss = var_weight * tf.image.total_variation(generated_image)
            loss += variational_loss
        grads = tape.gradient(loss, generated_image)
        return grads, loss, [style_loss, content_loss, variational_loss]

    def _update_image_with_style(
        self,
        generated_image,
        style_target,
        content_target,
        style_weight,
        content_weight,
        optimizer,
        var_weight,
    ):
        grads, loss, loss_list = self._grad_loss(
            generated_image,
            style_target,
            content_target,
            style_weight,
            content_weight,
            var_weight,
        )
        optimizer.apply_gradients([(grads, generated_image)])
        # keep the optimized pixels inside the valid [0, 255] range
        generated_image.assign(
            tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=255.0)
        )
        return loss_list
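

# --- Usage sketch (not part of the class above) ---
# A minimal driver, assuming "style.jpg" and "content.jpg" exist locally;
# the file names and every hyperparameter below are illustrative, not tuned.
# build() must run before get_features so the extractor returns the selected
# layer outputs.
if __name__ == "__main__":
    nst = NeuralStyleTransfer("style.jpg", "content.jpg", extractor="vgg19")
    nst.build(nst.get_output_layers())

    # fixed targets, computed once from the style and content images
    style_targets = nst.get_features(nst.style_image, feature_type="style")
    content_targets = nst.get_features(nst.content_image, feature_type="content")

    # optimize the pixels of a float copy of the content image
    generated_image = tf.Variable(tf.cast(nst.content_image, tf.float32))
    optimizer = tf.keras.optimizers.Adam(learning_rate=2.0)

    for step in range(100):
        style_loss, content_loss, variation_loss = nst._update_image_with_style(
            generated_image,
            style_targets,
            content_targets,
            style_weight=1e-2,
            content_weight=1e4,
            optimizer=optimizer,
            var_weight=1e-4,
        )
        if step % 10 == 0:
            print(
                f"step {step}: style={float(style_loss):.2f} "
                f"content={float(content_loss):.2f} tv={float(variation_loss):.2f}"
            )

    nst.tensor_to_image(generated_image).save("stylized.png")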