Muinez committed on
Commit
31ac1b2
1 Parent(s): 0151948

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +33 -0
  2. dbimutils.py +68 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, ConvNextV2ForImageClassification
4
+ from transformers import AutoModelForImageClassification
5
+ from torch import nn
6
+ import dbimutils as utils
7
+
8
# Use the GPU when torch reports one as available; otherwise stay on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Identifier passed to both from_pretrained calls below.
_CHECKPOINT = "Muinez/artwork-scorer"

image_processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(
    _CHECKPOINT,
    problem_type="multi_label_classification",
)
model = model.to(DEVICE)
12
+
13
def predict(img):
    """Score one artwork image with the module-level classifier.

    Args:
        img: PIL image coming from the ``gr.Image(type="pil")`` input. May be
            ``None`` when the form is submitted without an image.

    Returns:
        A ``(score, view_count_ratio)`` pair of Python floats: the sigmoid of
        the first two logits the model produces for the image.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Show a readable message in the UI instead of failing later with an
        # AttributeError when ``.convert`` is called on ``None``.
        raise gr.Error("Please provide an image to score.")

    prepared = utils.preprocess_image(img)
    encoded = image_processor(prepared, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        logits = model(**encoded).logits.cpu()

    # Independent per-output probabilities (multi-label head), first sample.
    probabilities = torch.sigmoid(logits)[0]

    # Hand plain floats to the gr.Number outputs rather than 0-d tensors.
    return probabilities[0].item(), probabilities[1].item()
23
+
24
# Build the web UI first, then start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Number(label="Score"),
        gr.Number(label="View count ratio (probably useless)"),
    ],
    title="Artwork scorer",
    description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
    allow_flagging="never",
)
demo.launch()
32
+
33
+
dbimutils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DanBooru IMage Utility functions
2
+ # Taken from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ import PIL
8
+
9
def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
    """Load an image file as an OpenCV-style array.

    Files whose name ends in ``.gif`` (in any letter case) are opened with
    Pillow, converted to RGB and then reordered to BGR. Every other path is
    handed to ``cv2.imread``.

    Args:
        img: Path of the image file, as a string.
        flag: Read flag forwarded to ``cv2.imread``; not used for GIF files.

    Returns:
        The BGR array for a GIF, otherwise whatever ``cv2.imread`` returns.
    """
    if img.lower().endswith(".gif"):
        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(img) as gif:
            rgb = np.array(gif.convert("RGB"))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return cv2.imread(img, flag)
17
+
18
+
19
def smart_24bit(img):
    """Reduce an image array to 8 bits per channel and three BGR channels.

    Args:
        img: Image as a NumPy array, either 2-D (grayscale) or 3-D with the
            channel count on the last axis.

    Returns:
        The converted array. 16-bit unsigned data is scaled down to uint8,
        2-D input is expanded to BGR, and 4-channel input has its fully
        transparent pixels painted white before the alpha channel is
        dropped. The input array itself is never modified.
    """
    # Inspect the dtype's properties instead of comparing object identity,
    # so 16-bit unsigned data is recognised whatever its byte order.
    if img.dtype.kind == "u" and img.dtype.itemsize == 2:
        # 65535 / 257 == 255, so the 16-bit range maps onto 0..255.
        img = (img / 257).astype(np.uint8)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        # Work on a copy so the caller's array is left untouched.
        img = img.copy()
        fully_transparent = img[:, :, 3] == 0
        img[fully_transparent] = [255, 255, 255, 255]
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img
30
+
31
+
32
def make_square(img, target_size):
    """Pad an image with white borders until it is square.

    The side of the result is the largest of the image height, the image
    width and ``target_size``. Padding is split between the two opposite
    edges, with any odd pixel going to the bottom or right edge.

    Args:
        img: Image array whose first two axes are height and width.
        target_size: Minimum side length of the padded result.

    Returns:
        The padded array produced by ``cv2.copyMakeBorder``.
    """
    height, width = img.shape[:2]
    side = max(height, width, target_size)

    pad_y = side - height
    pad_x = side - width
    top = pad_y // 2
    bottom = pad_y - top
    left = pad_x // 2
    right = pad_x - left

    white = [255, 255, 255]
    return cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=white
    )
47
+
48
+
49
def smart_resize(img, size):
    """Resize a square image to ``size`` x ``size``.

    Shrinking uses ``cv2.INTER_AREA`` and enlarging uses ``cv2.INTER_CUBIC``.
    An image whose first dimension already equals ``size`` is returned as is.

    Args:
        img: Image array; only ``img.shape[0]`` is inspected.
        size: Target side length in pixels.

    Returns:
        The resized array, or ``img`` itself when no resize is needed.
    """
    # Assumes the image has already gone through make_square
    current = img.shape[0]
    if current == size:
        return img
    method = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=method)
56
+
57
def preprocess_image(img, target_size=384):
    """Flatten transparency onto white and return a square RGB image.

    Args:
        img: Source ``PIL.Image.Image`` in any mode that ``convert('RGBA')``
            accepts.
        target_size: Side length, in pixels, of the returned image.
            Defaults to 384.

    Returns:
        An RGB ``PIL.Image.Image`` that has been padded to a square with
        white and resized to ``target_size`` x ``target_size``.
    """
    rgba = img.convert('RGBA')

    # Paste through the image's own alpha channel so transparent regions
    # show the white canvas underneath.
    canvas = Image.new('RGBA', rgba.size, 'WHITE')
    canvas.paste(rgba, mask=rgba)
    pixels = np.asarray(canvas.convert('RGB'))

    pixels = make_square(pixels, target_size)
    pixels = smart_resize(pixels, target_size)

    # The pixels are uint8 already, so no cast is needed before wrapping.
    return Image.fromarray(pixels)