DamarJati committed (verified)
Commit 41de93f
1 parent: 7332dc9

Upload app (2).py

Files changed (1):
  modules/app (2).py  (+339, -0)
modules/app (2).py ADDED
@@ -0,0 +1,339 @@
import argparse
import os

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image

TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Demo for the WaifuDiffusion tagger models

Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()

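# Example invocation with the flags above (illustrative values; the path is the
# one shown in the diff header):
#   python "modules/app (2).py" --score-general-threshold 0.5 --score-character-threshold 0.8 --share
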
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh

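# Worked example (assumed toy values): for probs = [0.9, 0.8, 0.3, 0.1], the
# consecutive gaps after sorting are [0.1, 0.5, 0.2]; the largest gap sits
# between 0.8 and 0.3, so mcut_threshold returns (0.8 + 0.3) / 2 = 0.55 and
# only the first two tags pass the threshold.
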
class Predictor:
    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        target_size = self.model_target_size

        # Flatten any transparency onto a white background
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        sorted_general_strings = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = [x[0] for x in sorted_general_strings]
        sorted_general_strings = (
            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
        )

        return sorted_general_strings, rating, character_res, general_res


def main():
    args = parse_args()

    predictor = Predictor()

    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    sorted_general_strings = gr.Textbox(label="Output (string)")
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    clear.add(
                        [
                            sorted_general_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

            submit.click(
                predictor.predict,
                inputs=[
                    image,
                    model_repo,
                    general_thresh,
                    general_mcut_enabled,
                    character_thresh,
                    character_mcut_enabled,
                ],
                outputs=[sorted_general_strings, rating, character_res, general_res],
            )

            gr.Examples(
                [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
                inputs=[
                    image,
                    model_repo,
                    general_thresh,
                    general_mcut_enabled,
                    character_thresh,
                    character_mcut_enabled,
                ],
            )

    demo.queue(max_size=10)
    demo.launch(share=args.share)  # honor the --share CLI flag


if __name__ == "__main__":
    main()
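For reference, the tagger can also be driven without the Gradio UI by calling Predictor.predict directly. The sketch below is illustrative only: it assumes the module above has been saved under an importable name (or the snippet is appended to it), that "example.png" is a hypothetical local image path, and it reuses the argparse default thresholds.

from PIL import Image

predictor = Predictor()
img = Image.open("example.png").convert("RGBA")  # prepare_image expects an alpha channel

tags, rating, characters, general = predictor.predict(
    img,
    SWINV2_MODEL_DSV3_REPO,  # any repo from dropdown_list works
    0.35,   # general tag threshold (argparse default)
    False,  # general MCut disabled
    0.85,   # character tag threshold (argparse default)
    False,  # character MCut disabled
)
print(rating)  # dict of rating label -> confidence
print(tags)    # comma-separated general tags, highest confidence first

The first call downloads model.onnx and selected_tags.csv from the chosen repo via huggingface_hub; subsequent calls with the same repo reuse the cached ONNX session.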