DamarJati commited on
Commit
6b0140b
·
verified ·
1 Parent(s): 1ef4f73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -202
app.py CHANGED
@@ -44,210 +44,12 @@ def controlnet_process_func(image, controlnet_type, model):
44
  def intpaint_func (image, controlnet_type, model):
45
  # Update fungsi sesuai kebutuhan
46
  return controlnet_process(image, controlnet_type, model)
47
-
48
-
49
-
50
- #wd tagger
51
-
52
- # Dataset v3 series of models:
53
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
54
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
55
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
56
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
57
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
58
-
59
- # Dataset v2 series of models:
60
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
61
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
62
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
63
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
64
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
65
-
66
- # Files to download from the repos
67
- MODEL_FILENAME = "model.onnx"
68
- LABEL_FILENAME = "selected_tags.csv"
69
-
70
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
71
- kaomojis = [ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||", ]
72
-
73
- def parse_args() -> argparse.Namespace:
74
- parser = argparse.ArgumentParser()
75
- parser.add_argument("--score-slider-step", type=float, default=0.05)
76
- parser.add_argument("--score-general-threshold", type=float, default=0.35)
77
- parser.add_argument("--score-character-threshold", type=float, default=0.85)
78
- parser.add_argument("--share", action="store_true")
79
- return parser.parse_args()
80
-
81
-
82
- def load_labels(dataframe) -> list[str]:
83
- name_series = dataframe["name"]
84
- name_series = name_series.map(
85
- lambda x: x.replace("_", " ") if x not in kaomojis else x
86
- )
87
- tag_names = name_series.tolist()
88
-
89
- rating_indexes = list(np.where(dataframe["category"] == 9)[0])
90
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
91
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
92
- return tag_names, rating_indexes, general_indexes, character_indexes
93
-
94
-
95
- def mcut_threshold(probs):
96
- """
97
- Maximum Cut Thresholding (MCut)
98
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
99
- for Multi-label Classification. In 11th International Symposium, IDA 2012
100
- (pp. 172-183).
101
- """
102
- sorted_probs = probs[probs.argsort()[::-1]]
103
- difs = sorted_probs[:-1] - sorted_probs[1:]
104
- t = difs.argmax()
105
- thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
106
- return thresh
107
-
108
-
109
- class Predictor:
110
- def __init__(self):
111
- self.model_target_size = None
112
- self.last_loaded_repo = None
113
-
114
- def download_model(self, model_repo):
115
- csv_path = huggingface_hub.hf_hub_download(
116
- model_repo,
117
- LABEL_FILENAME,
118
- )
119
- model_path = huggingface_hub.hf_hub_download(
120
- model_repo,
121
- MODEL_FILENAME,
122
- )
123
- return csv_path, model_path
124
-
125
- def load_model(self, model_repo):
126
- if model_repo == self.last_loaded_repo:
127
- return
128
-
129
- csv_path, model_path = self.download_model(model_repo)
130
-
131
- tags_df = pd.read_csv(csv_path)
132
- sep_tags = load_labels(tags_df)
133
-
134
- self.tag_names = sep_tags[0]
135
- self.rating_indexes = sep_tags[1]
136
- self.general_indexes = sep_tags[2]
137
- self.character_indexes = sep_tags[3]
138
-
139
- model = rt.InferenceSession(model_path)
140
- _, height, width, _ = model.get_inputs()[0].shape
141
- self.model_target_size = height
142
-
143
- self.last_loaded_repo = model_repo
144
- self.model = model
145
-
146
- def prepare_image(self, image):
147
- target_size = self.model_target_size
148
-
149
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
150
- canvas.alpha_composite(image)
151
- image = canvas.convert("RGB")
152
-
153
- # Pad image to square
154
- image_shape = image.size
155
- max_dim = max(image_shape)
156
- pad_left = (max_dim - image_shape[0]) // 2
157
- pad_top = (max_dim - image_shape[1]) // 2
158
-
159
- padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
160
- padded_image.paste(image, (pad_left, pad_top))
161
-
162
- # Resize
163
- if max_dim != target_size:
164
- padded_image = padded_image.resize(
165
- (target_size, target_size),
166
- Image.BICUBIC,
167
- )
168
-
169
- # Convert to numpy array
170
- image_array = np.asarray(padded_image, dtype=np.float32)
171
-
172
- # Convert PIL-native RGB to BGR
173
- image_array = image_array[:, :, ::-1]
174
-
175
- return np.expand_dims(image_array, axis=0)
176
-
177
- @spaces.GPU()
178
- def predict(
179
- self,
180
- image,
181
- model_repo,
182
- general_thresh,
183
- general_mcut_enabled,
184
- character_thresh,
185
- character_mcut_enabled,
186
- ):
187
- self.load_model(model_repo)
188
-
189
- image = self.prepare_image(image)
190
-
191
- input_name = self.model.get_inputs()[0].name
192
- label_name = self.model.get_outputs()[0].name
193
- preds = self.model.run([label_name], {input_name: image})[0]
194
-
195
- labels = list(zip(self.tag_names, preds[0].astype(float)))
196
-
197
- # First 4 labels are actually ratings: pick one with argmax
198
- ratings_names = [labels[i] for i in self.rating_indexes]
199
- rating = dict(ratings_names)
200
-
201
- # Then we have general tags: pick any where prediction confidence > threshold
202
- general_names = [labels[i] for i in self.general_indexes]
203
-
204
- if general_mcut_enabled:
205
- general_probs = np.array([x[1] for x in general_names])
206
- general_thresh = mcut_threshold(general_probs)
207
-
208
- general_res = [x for x in general_names if x[1] > general_thresh]
209
- general_res = dict(general_res)
210
-
211
- # Everything else is characters: pick any where prediction confidence > threshold
212
- character_names = [labels[i] for i in self.character_indexes]
213
-
214
- if character_mcut_enabled:
215
- character_probs = np.array([x[1] for x in character_names])
216
- character_thresh = mcut_threshold(character_probs)
217
- character_thresh = max(0.15, character_thresh)
218
-
219
- character_res = [x for x in character_names if x[1] > character_thresh]
220
- character_res = dict(character_res)
221
-
222
- sorted_general_strings = sorted(
223
- general_res.items(),
224
- key=lambda x: x[1],
225
- reverse=True,
226
- )
227
- sorted_general_strings = [x[0] for x in sorted_general_strings]
228
- sorted_general_strings = (
229
- ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
230
- )
231
-
232
- return sorted_general_strings, rating, character_res, general_res
233
-
234
-
235
 
236
- args = parse_args()
237
- predictor = Predictor()
238
 
239
- dropdown_list = [
240
- SWINV2_MODEL_DSV3_REPO,
241
- CONV_MODEL_DSV3_REPO,
242
- VIT_MODEL_DSV3_REPO,
243
- VIT_LARGE_MODEL_DSV3_REPO,
244
- EVA02_LARGE_MODEL_DSV3_REPO,
245
- MOAT_MODEL_DSV2_REPO,
246
- SWIN_MODEL_DSV2_REPO,
247
- CONV_MODEL_DSV2_REPO,
248
- CONV2_MODEL_DSV2_REPO,
249
- VIT_MODEL_DSV2_REPO,
250
- ]
251
 
252
  with gr.Blocks(css= "style.css") as app:
253
  # Dropdown untuk memilih model di luar tab dengan lebar kecil
 
44
  def intpaint_func (image, controlnet_type, model):
45
  # Update fungsi sesuai kebutuhan
46
  return controlnet_process(image, controlnet_type, model)
47
+
48
+ def intpaint_func (image, controlnet_type, model):
49
+ # Update fungsi sesuai kebutuhan
50
+ return controlnet_process(image, controlnet_type, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  with gr.Blocks(css= "style.css") as app:
55
  # Dropdown untuk memilih model di luar tab dengan lebar kecil