# Hugging Face Spaces page status when this file was captured: "Runtime error".
import gradio as gr
import tensorflow as tf

from model import NeuralStyleTransfer
def model_fn(
    style,
    content,
    extractor="inception_v3",
    n_content_layers=3,
    n_style_layers=2,
    epochs=4,
    learning_rate=60.0,
    steps_per_epoch=100,
    style_weight=1e-2,
    content_weight=1e-4,
    var_weight=1e-12,
):
    """Run neural style transfer of ``style`` onto ``content`` and return the result.

    This is the callback wired to the Gradio interface built by
    ``hugging_face``. Gradio passes input values positionally, so the
    parameter order here must match that interface's ``inputs`` list.

    Args:
        style: Style image as delivered by a Gradio image input
            (presumably a NumPy array -- confirm against ``NeuralStyleTransfer``).
        content: Content image, same format as ``style``.
        extractor: Name of the feature-extractor backbone handed to
            ``NeuralStyleTransfer``.
        n_content_layers: Number of content layers to use.
        n_style_layers: Number of style layers to use.
        epochs: Number of epochs to train for.
        learning_rate: Initial learning rate.
        steps_per_epoch: Number of optimisation steps per epoch.
        style_weight: Weight of the style loss.
        content_weight: Weight of the content loss.
        var_weight: Weight of the total-variation loss.

    Returns:
        Whatever ``NeuralStyleTransfer.fit_style_transfer`` returns; the
        Gradio interface renders it as an image.
    """
    # Gradio sliders may deliver floats (e.g. 3.0 or 2.37); the layer,
    # epoch and step counts are counts, so coerce them to int.
    model = NeuralStyleTransfer(
        style_image=style,
        content_image=content,
        extractor=extractor,
        n_content_layers=int(n_content_layers),
        n_style_layers=int(n_style_layers),
    )
    # Every training hyper-parameter is forwarded from the caller; only the
    # display/logging options below are fixed for the web demo.
    return model.fit_style_transfer(
        epochs=int(epochs),
        learning_rate=learning_rate,
        steps_per_epoch=int(steps_per_epoch),
        style_weight=style_weight,
        content_weight=content_weight,
        show_image=True,
        show_interval=90,
        var_weight=var_weight,
        terminal=False,
    )
def hugging_face():
    """Build the Gradio interface for the neural-style-transfer demo.

    Gradio passes the values of ``inputs`` to ``model_fn`` positionally, so
    the order of the components below must match ``model_fn``'s parameters:
    style, content, extractor, n_content_layers, n_style_layers, epochs,
    learning_rate, steps_per_epoch, style_weight, content_weight, var_weight.

    Returns:
        gr.Interface: The interface, not yet launched; call ``.launch()``
        on it to serve the app.
    """
    demo = gr.Interface(
        fn=model_fn,
        inputs=[
            # Labelled so users can tell which upload is the style source.
            gr.Image(label="style"),
            gr.Image(label="content"),
            gr.Dropdown(
                ["inception_v3", "vgg19", "resnet50", "mobilenet_v2"],
                label="extractor",
                # `value` sets the initial selection; the old `default`
                # keyword is not accepted for this by current Gradio.
                value="inception_v3",
                info="Feature extractor to use.",
            ),
            # Integer-valued hyper-parameters use step=1 so the UI cannot
            # produce fractional layer, epoch or step counts.
            gr.Slider(
                1,
                5,
                value=3,
                step=1,
                label="n_content_layers",
                info="Number of content layers to use.",
            ),
            gr.Slider(
                1,
                5,
                value=2,
                step=1,
                label="n_style_layers",
                info="Number of style layers to use.",
            ),
            gr.Slider(
                2,
                20,
                value=4,
                step=1,
                label="epochs",
                info="Number of epochs to train for.",
            ),
            gr.Slider(
                1, 100, value=60, label="learning_rate", info="Initial Learning rate."
            ),
            gr.Slider(
                1,
                100,
                value=100,
                step=1,
                label="steps_per_epoch",
                info="Number of steps per epoch.",
            ),
            gr.Slider(
                1e-4,
                1e-2,
                value=1e-2,
                label="style_weight",
                info="Weight of style loss.",
            ),
            gr.Slider(
                1e-4,
                1e-2,
                value=1e-4,
                label="content_weight",
                info="Weight of content loss.",
            ),
            gr.Slider(
                1e-12,
                1e-9,
                value=1e-12,
                label="var_weight",
                info="Weight of total variation loss.",
            ),
        ],
        outputs="image",
    )
    return demo
if __name__ == "__main__":
    # Script entry point: construct the Gradio interface, then start serving it.
    # `demo` stays bound at module scope, matching the original script.
    demo = hugging_face()
    demo.launch()