File size: 2,382 Bytes
ef164a1
d44e389
 
dab5d0e
d44e389
 
dab5d0e
 
d44e389
 
 
 
ef164a1
 
 
 
 
d44e389
24cab13
dab5d0e
e0067df
dab5d0e
 
 
d44e389
dab5d0e
d44e389
 
dab5d0e
d44e389
dab5d0e
d44e389
dab5d0e
 
d44e389
dab5d0e
d44e389
 
 
 
8b21fe0
 
 
 
 
d44e389
dab5d0e
 
 
24cab13
b74f8d7
d44e389
 
 
08d07b3
 
4cee7a5
08d07b3
d44e389
 
4cee7a5
dab5d0e
 
 
 
 
 
 
 
 
8b21fe0
dab5d0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import os

import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_url, cached_download

from inference.face_detector import StatRetinaFaceDetector
from inference.model_pipeline import VSNetModelPipeline
from inference.onnx_model import ONNXModel

# Root logger: INFO and above, each line prefixed with a timestamp and a
# left-padded level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Size value handed to StatRetinaFaceDetector when the pipeline is built.
# The demo article quotes timings "at 256x256 resolution" for this value.
MODEL_IMG_SIZE = 256

# Number of processed images; inference() increments it and logs the total.
# NOTE(review): starts at 35 rather than 0 — presumably carried over from an
# earlier deployment; confirm whether that offset is still wanted.
usage_count = 35
def load_model():
    """Fetch the JJBAGAN ONNX weights and assemble the style-transfer pipeline.

    Side effects:
        Binds the module-level globals ``model`` (the ``ONNXModel``) and
        ``pipeline`` (the ``VSNetModelPipeline`` that ``inference`` calls).

    Returns:
        The constructed ``ONNXModel``.
    """
    global model, pipeline

    repo_id = "Podtekatel/JJBAGAN"
    filename = "198_jjba_8_k_2_099_ep.onnx"

    # HF_TOKEN comes from the environment (None when unset) and is forwarded
    # as the Hub auth token.
    # NOTE(review): cached_download is deprecated and was removed in newer
    # huggingface_hub releases; hf_hub_download is the supported replacement.
    weights_url = hf_hub_url(repo_id, filename)
    hf_token = os.getenv('HF_TOKEN')
    model_path = cached_download(weights_url, use_auth_token=hf_token)

    model = ONNXModel(model_path)
    face_detector = StatRetinaFaceDetector(MODEL_IMG_SIZE)
    pipeline = VSNetModelPipeline(
        model,
        face_detector,
        background_resize=1024,
        no_detected_resize=1024,
    )
    return model


# Build the model at import time so the global `pipeline` exists before the
# first call to inference().
load_model()

def inference(img):
    """Run the style-transfer pipeline on a single uploaded image.

    Args:
        img: The image supplied by the Gradio input component (declared with
            ``type="pil"``), or ``None`` when the user submits without
            uploading anything.

    Returns:
        The stylized image as a ``PIL.Image``, or ``None`` when no input image
        was provided.
    """
    global usage_count

    # An empty Gradio input arrives as None. np.array(None) would yield a 0-d
    # object array and the pipeline would fail with an opaque error, so bail
    # out early and leave the output empty. Empty submissions are not counted.
    if img is None:
        return None

    # NOTE(review): the pipeline presumably expects an HxWxC uint8 RGB array,
    # which is what np.array() gives for an RGB PIL image — confirm.
    out_img = Image.fromarray(pipeline(np.array(img)))

    usage_count += 1
    logging.info('Usage count is %d', usage_count)

    return out_img


title = "JJStyleTransfer"
description = (
    "Gradio Demo for JoJo's Bizarre Adventure style transfer. To use it, simply upload your image, "
    "or click one of the examples to load them. Press ❤️ if you like this space!"
)
# NOTE(review): the contact address below reads '[email protected]', which looks
# like an email-obfuscation placeholder introduced when this file was copied from
# a web page; confirm the real address before publishing.
article = (
    "This is one of my successful experiments on style transfer. "
    "I've built my own pipeline, generator model and private dataset to train this model.<br>"
    "The model pipeline used in this project is an improved CartoonGAN.<br>"
    "The model was trained on an RTX 2080 Ti for 1.5 days with a batch size of 7.<br>"
    "The model weights take 64 MB in ONNX fp32 format; inference takes 25 ms on GPU "
    "and 150 ms on CPU at 256x256 resolution.<br>"
    "If you want to use this app or integrate this model into yours, "
    "please contact me at email '[email protected]'."
)

# Every non-hidden file in demo/ becomes a clickable example, in name order.
# Hidden files (e.g. .DS_Store, .gitkeep) are skipped because Gradio would
# presumably try to load them as example images.
imgs_folder = 'demo'
examples = [
    [os.path.join(imgs_folder, img_filename)]
    for img_filename in sorted(os.listdir(imgs_folder))
    if not img_filename.startswith('.')
]

demo = gr.Interface(
    fn=inference,
    # gr.inputs / gr.outputs are deprecated in Gradio 3; gr.Image serves as
    # both the input and the output component.
    inputs=[gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# One queue worker: requests are processed one at a time, which also keeps the
# unsynchronized usage_count increment in inference() safe.
demo.queue(concurrency_count=1)
demo.launch()