"""Gradio demo that rates an image with the NSFWJS ONNX classifier."""
import os
from functools import lru_cache

import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.utils import open_onnx_model


@lru_cache()
def _onnx_model():
    """Download and open the NSFWJS ONNX model.

    The file ``nsfw/nsfwjs.onnx`` is fetched from the
    ``deepghs/imgutils-models`` Hugging Face repository and opened with
    ``open_onnx_model``. ``lru_cache`` memoizes the result, so every call
    after the first returns the same opened model object.
    """
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        'nsfw/nsfwjs.onnx',
    ))


def _image_preprocess(image, size: int = 224) -> np.ndarray:
    """Turn *image* into a batched array ready for the model.

    The image is loaded as RGB, resized to ``size x size`` with
    nearest-neighbour sampling, and scaled from ``[0, 255]`` to ``[0, 1]``.

    :param image: Anything ``imgutils.data.load_image`` accepts (the Gradio
        UI below passes a PIL image).
    :param size: Side length of the square resize target.
    :return: Array of shape ``(1, size, size, 3)``: a length-1 batch axis in
        front of the height x width x RGB pixel data.
    """
    image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
    # `[None, ...]` prepends a batch axis of length 1.
    return (np.array(image) / 255.0)[None, ...]


# Class names, in the same order as the scores in the model's output row.
_LABELS = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']


def predict(image):
    """Score *image* against every NSFWJS class.

    :param image: Image accepted by ``_image_preprocess``. ``None`` (what the
        Gradio image input yields when nothing has been uploaded) is rejected.
    :return: Mapping of each label in ``_LABELS`` to its score as a Python
        ``float``.
    :raises gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Show a readable message in the UI instead of letting None reach
        # load_image and fail there.
        raise gr.Error('Please upload an image first.')

    input_ = _image_preprocess(image).astype(np.float32)
    # 'input_1' / 'dense_3' are the model graph's input and output names.
    output_, = _onnx_model().run(['dense_3'], {'input_1': input_})
    # Row 0 of the output holds the scores for our single-image batch.
    return dict(zip(_LABELS, map(float, output_[0])))


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_btn_submit = gr.Button(value='Tagging', variant='primary')
            with gr.Column():
                gr_ratings = gr.Label(label='Ratings')

        gr_btn_submit.click(
            predict,
            inputs=[gr_input_image],
            outputs=[gr_ratings],
        )

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 1 so queue() never receives None.
    # NOTE(review): the positional argument is presumably `concurrency_count`
    # (Gradio 3.x signature) -- confirm against the installed Gradio version.
    demo.queue(os.cpu_count() or 1).launch()