johnowhitaker's picture
Create app.py
08ac7ce
raw
history blame
1.03 kB
import os
from os.path import exists as file_exists  # os.path has no `file_exists`; alias keeps the name used below

from fastai.vision.all import *
import gradio as gr
import requests
# File name of the exported learner on disk.
model_fn = 'quick_224px'

# Download the model only when it is not already present.
# NOTE(review): `url` is not defined anywhere in the visible source; it must be
# assigned before this point, otherwise the download branch raises NameError.
if not file_exists(model_fn):
    # Stream into a temporary sibling file and rename it only after the whole
    # body has been written, so an interrupted download never leaves a
    # truncated file that would pass the existence check on the next start.
    partial_fn = model_fn + '.part'
    # (connect, read) timeouts so a stalled server cannot hang start-up forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as r:
        r.raise_for_status()
        with open(partial_fn, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    os.replace(partial_fn, model_fn)
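# Helper functions defined ahead of loading the model further down
# (presumably the exported learner refers to them by name — confirm).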
def open_img(fn:Path):
    """Open the image at ``fn``, convert it to RGB mode and return a copy."""
    rgb_image = Image.open(fn).convert('RGB')
    return rgb_image.copy()
def get_images(path):
    """Return the image files under ``path/'train'`` followed by those under ``path/'test'``."""
    train_files = get_image_files(path/'train')
    test_files = get_image_files(path/'test')
    return train_files + test_files
def get_x(item):
    """Return the crop with box (0, 0, 512, 512) of the image at ``item`` as an array."""
    # Box is (left, upper, right, lower): the top-left 512x512 region.
    input_region = open_img(item).crop((0, 0, 512, 512))
    return np.array(input_region)
def get_y(item):
    """Return the crop with box (512, 0, 1024, 512) of the image at ``item`` as an array."""
    # Box is (left, upper, right, lower): the 512x512 region starting at x=512.
    target_region = open_img(item).crop((512, 0, 1024, 512))
    return np.array(target_region)
# Load the exported learner from the file named by model_fn.
# NOTE(review): fastai documents load_learner as pickle-based, so the model
# file should come from a trusted source — confirm.
sketch_model = load_learner(model_fn)
def sketchify(image_path):
    """Run ``sketch_model`` on ``image_path`` and return its first output as a NumPy array."""
    first_output = sketch_model.predict(image_path)[0]
    # Move axis 0 to the end (presumably channels-first -> channels-last; confirm).
    reordered = first_output.permute(1, 2, 0)
    return reordered.numpy()
# Build the Gradio UI: one image input, handed to sketchify as a file path,
# and one image output that receives the returned NumPy array.
sketch_input = gr.inputs.Image(shape=(512, 512), type="filepath")
sketch_output = gr.outputs.Image(type="numpy", label="Output Image")
iface = gr.Interface(
    fn=sketchify,
    inputs=[sketch_input],
    outputs=[sketch_output],
)

# Start serving the interface.
iface.launch()