tori29umai commited on
Commit
3542be4
·
1 Parent(s): 9330b76
Files changed (7) hide show
  1. app.py +154 -0
  2. config.json +57 -0
  3. requirements.txt +21 -0
  4. utils/prompt_analysis.py +41 -0
  5. utils/prompt_utils.py +28 -0
  6. utils/tagger.py +149 -0
  7. utils/utils.py +76 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from gradio_imageslider import ImageSlider
4
+ import torch
5
+
6
+ torch.jit.script = lambda f: f
7
+ from diffusers import (
8
+ ControlNetModel,
9
+ StableDiffusionXLControlNetImg2ImgPipeline,
10
+ DDIMScheduler,
11
+ )
12
+ from controlnet_aux import AnylineDetector
13
+ from compel import Compel, ReturnedEmbeddingsType
14
+ from PIL import Image
15
+ import os
16
+ import time
17
+ import numpy as np
18
+
19
+ from utils.utils import load_cn_model, load_cn_config, load_tagger_model, resize_image_aspect_ratio, base_generation
20
+ from utils.prompt_analysis import PromptAnalysis
21
+
22
+ path = os.getcwd()
23
+ cn_dir = f"{path}/controlnet"
24
+ tagger_dir = f"{path}/tagger"
25
+
26
+ load_cn_model(cn_dir)
27
+ load_cn_config(cn_dir)
28
+ load_tagger_model(tagger_dir)
29
+
30
+ IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
31
+ IS_SPACE = os.environ.get("SPACE_ID", None) is not None
32
+
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ dtype = torch.float16
35
+
36
+ LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
37
+
38
+ print(f"device: {device}")
39
+ print(f"dtype: {dtype}")
40
+ print(f"low memory: {LOW_MEMORY}")
41
+
42
+
43
+ model = "cagliostrolab/animagine-xl-3.1"
44
+ scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
45
+ controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=torch.float16, use_safetensors=True)
46
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
47
+ model,
48
+ controlnet=controlnet,
49
+ torch_dtype=dtype,
50
+ variant="fp16",
51
+ use_safetensors=True,
52
+ scheduler=scheduler,
53
+ )
54
+
55
+ compel = Compel(
56
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
57
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
58
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
59
+ requires_pooled=[False, True],
60
+ )
61
+ pipe = pipe.to(device)
62
+
63
+
64
+
65
+ @spaces.GPU
66
+ def predict(
67
+ input_image,
68
+ prompt,
69
+ negative_prompt,
70
+ controlnet_conditioning_scale,
71
+ ):
72
+ base_size =input_image.size
73
+ resize_image= resize_image_aspect_ratio(input_image)
74
+ resize_image_size = resize_image.size
75
+ width = resize_image_size[0]
76
+ height = resize_image_size[1]
77
+ white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
78
+ conditioning, pooled = compel([prompt, negative_prompt])
79
+ generator = torch.manual_seed(0)
80
+ last_time = time.time()
81
+
82
+ output_image = pipe(
83
+ image=white_base_pil,
84
+ control_image=resize_image,
85
+ strength=1.0,
86
+ prompt_embeds=conditioning[0:1],
87
+ pooled_prompt_embeds=pooled[0:1],
88
+ negative_prompt_embeds=conditioning[1:2],
89
+ negative_pooled_prompt_embeds=pooled[1:2],
90
+ width=width,
91
+ height=height,
92
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
93
+ controlnet_start=0.0,
94
+ controlnet_end=1.0,
95
+ generator=generator,
96
+ num_inference_steps=30,
97
+ guidance_scale=8.5,
98
+ eta=1.0,
99
+ )
100
+ print(f"Time taken: {time.time() - last_time}")
101
+ output_image = output_image.resize(base_size, Image.LANCZOS)
102
+ return output_image
103
+
104
+
105
+ css = """
106
+ #intro{
107
+ # max-width: 32rem;
108
+ # text-align: center;
109
+ # margin: 0 auto;
110
+ }
111
+ """
112
+
113
+ with gr.Blocks(css=css) as demo:
114
+ with gr.Row() as block:
115
+ with gr.Column():
116
+ # 画像アップロード用の行
117
+ with gr.Row():
118
+ with gr.Column():
119
+ input_image = gr.Image(label="入力画像", type="pil")
120
+
121
+ # プロンプト入力用の行
122
+ with gr.Row():
123
+ prompt_analysis = PromptAnalysis(tagger_dir)
124
+ [prompt, nega] = PromptAnalysis.layout(input_image)
125
+ # 画像の詳細設定用のスライダー行
126
+ with gr.Row():
127
+ controlnet_conditioning_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, interactive=True, label="ラインアートの忠実度")
128
+
129
+ # 画像生成ボタンの行
130
+ with gr.Row():
131
+ generate_button = gr.Button("生成", interactive=False)
132
+
133
+ with gr.Column():
134
+ output_image = gr.Image(type="pil", label="Output Image")
135
+
136
+ # インプットとアウトプットの設定
137
+ inputs = [
138
+ input_image,
139
+ prompt,
140
+ nega,
141
+ controlnet_conditioning_scale,
142
+ ]
143
+ outputs = [output_image]
144
+
145
+ # ボタンのクリックイベントを設定
146
+ generate_button.click(
147
+ fn=predict,
148
+ inputs=[input_image, prompt, nega, controlnet_conditioning_scale],
149
+ outputs=[output_image]
150
+ )
151
+
152
+ # デモの設定と起動
153
+ demo.queue(api_open=True)
154
+ demo.launch(show_api=True)
config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.27.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.29.0
2
+ accelerate
3
+ transformers
4
+ torchvision
5
+ xformers
6
+ accelerate
7
+ invisible-watermark
8
+ huggingface-hub
9
+ hf-transfer
10
+ gradio_imageslider==0.0.20
11
+ compel
12
+ opencv-python
13
+ numpy
14
+ diffusers==0.27.0
15
+ transformers
16
+ accelerate
17
+ safetensors
18
+ hidiffusion==0.1.8
19
+ spaces
20
+ torch==2.2
21
+ controlnet-aux==0.0.9
utils/prompt_analysis.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ from utils.prompt_utils import remove_color
6
+ from utils.tagger import modelLoad, analysis
7
+
8
+
9
+ class PromptAnalysis:
10
+ def __init__(self, app_config, post_filter=True,
11
+ default_nagative_prompt="lowres, error, extra digit, fewer digits, cropped, worst quality, "
12
+ "low quality, normal quality, jpeg artifacts, blurry"):
13
+ self.default_nagative_prompt = default_nagative_prompt
14
+ self.post_filter = post_filter
15
+ self.model = None
16
+ self.model_dir = os.path.join(app_config.dpath, 'models/tagger')
17
+
18
+ def layout(self, lang_util, input_image):
19
+ with gr.Column():
20
+ with gr.Row():
21
+ self.prompt = gr.Textbox(label=lang_util.get_text("prompt"), lines=3)
22
+ with gr.Row():
23
+ self.negative_prompt = gr.Textbox(label=lang_util.get_text("negative_prompt"), lines=3, value=self.default_nagative_prompt)
24
+ with gr.Row():
25
+ self.prompt_analysis_button = gr.Button(lang_util.get_text("analyze_prompt"))
26
+
27
+ self.prompt_analysis_button.click(
28
+ self.process_prompt_analysis,
29
+ inputs=[input_image],
30
+ outputs=self.prompt
31
+ )
32
+ return [self.prompt, self.negative_prompt]
33
+
34
+ def process_prompt_analysis(self, input_image_path):
35
+ if self.model is None:
36
+ self.model = modelLoad(self.model_dir)
37
+ tags = analysis(input_image_path, self.model_dir, self.model)
38
+ tags_list = tags
39
+ if self.post_filter:
40
+ tags_list = remove_color(tags)
41
+ return tags_list
utils/prompt_utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def remove_duplicates(base_prompt):
2
+ # タグの重複を取り除く
3
+ prompt_list = base_prompt.split(", ")
4
+ seen = set()
5
+ unique_tags = []
6
+ for tag in prompt_list :
7
+ tag_clean = tag.lower().strip()
8
+ if tag_clean not in seen and tag_clean != "":
9
+ unique_tags.append(tag)
10
+ seen.add(tag_clean)
11
+ return ", ".join(unique_tags)
12
+
13
+
14
+ def remove_color(base_prompt):
15
+ # タグの色情報を取り除く
16
+ prompt_list = base_prompt.split(", ")
17
+ color_list = ["pink", "red", "orange", "brown", "yellow", "green", "blue", "purple", "blonde", "colored skin", "white hair"]
18
+ # カラータグを除去します。
19
+ cleaned_tags = [tag for tag in prompt_list if all(color.lower() not in tag.lower() for color in color_list)]
20
+ return ", ".join(cleaned_tags)
21
+
22
+
23
+ def execute_prompt(execute_tags, base_prompt):
24
+ prompt_list = base_prompt.split(", ")
25
+ # execute_tagsを除去
26
+ filtered_tags = [tag for tag in prompt_list if tag not in execute_tags]
27
+ # 最終的なプロンプトを生成
28
+ return ", ".join(filtered_tags)
utils/tagger.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # https://github.com/kohya-ss/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py
3
+
4
+ import csv
5
+ import os
6
+ os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
7
+
8
+ from PIL import Image
9
+ import cv2
10
+ import numpy as np
11
+ from pathlib import Path
12
+ import onnx
13
+ import onnxruntime as ort
14
+
15
+ # from wd14 tagger
16
+ IMAGE_SIZE = 448
17
+
18
+ model = None # Initialize model variable
19
+
20
+
21
+ def convert_array_to_bgr(array):
22
+ """
23
+ Convert a NumPy array image to BGR format regardless of its original format.
24
+
25
+ Parameters:
26
+ - array: NumPy array of the image.
27
+
28
+ Returns:
29
+ - A NumPy array representing the image in BGR format.
30
+ """
31
+ # グレースケール画像(2次元配列)
32
+ if array.ndim == 2:
33
+ # グレースケールをBGRに変換(3チャンネルに拡張)
34
+ bgr_array = np.stack((array,) * 3, axis=-1)
35
+ # RGBAまたはRGB画像(3次元配列)
36
+ elif array.ndim == 3:
37
+ # RGBA画像の場合、アルファチャンネルを削除
38
+ if array.shape[2] == 4:
39
+ array = array[:, :, :3]
40
+ # RGBをBGRに変換
41
+ bgr_array = array[:, :, ::-1]
42
+ else:
43
+ raise ValueError("Unsupported array shape.")
44
+
45
+ return bgr_array
46
+
47
+
48
+ def preprocess_image(image):
49
+ image = np.array(image)
50
+ image = convert_array_to_bgr(image)
51
+
52
+ size = max(image.shape[0:2])
53
+ pad_x = size - image.shape[1]
54
+ pad_y = size - image.shape[0]
55
+ pad_l = pad_x // 2
56
+ pad_t = pad_y // 2
57
+ image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
58
+
59
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
60
+ image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
61
+
62
+ image = image.astype(np.float32)
63
+ return image
64
+
65
+ def modelLoad(model_dir):
66
+ onnx_path = os.path.join(model_dir, "model.onnx")
67
+ # 実行プロバイダーをCPUのみに指定
68
+ providers = ['CPUExecutionProvider']
69
+ # InferenceSessionの作成時にプロバイダーのリストを指定
70
+ ort_session = ort.InferenceSession(onnx_path, providers=providers)
71
+ input_name = ort_session.get_inputs()[0].name
72
+
73
+ # 実際に使用されているプロバイダーを取得して表示
74
+ actual_provider = ort_session.get_providers()[0] # 使用されているプロバイダー
75
+ print(f"Using provider: {actual_provider}")
76
+
77
+ return [ort_session, input_name]
78
+
79
+ def analysis(image_path, model_dir, model):
80
+ ort_session = model[0]
81
+ input_name = model[1]
82
+
83
+ with open(os.path.join(model_dir, "selected_tags.csv"), "r", encoding="utf-8") as f:
84
+ reader = csv.reader(f)
85
+ l = [row for row in reader]
86
+ header = l[0] # tag_id,name,category,count
87
+ rows = l[1:]
88
+ assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
89
+
90
+ general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
91
+ character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
92
+
93
+ tag_freq = {}
94
+ undesired_tags = ["transparent background"]
95
+
96
+ # 画像をロードして前処理する
97
+ if image_path:
98
+ # 画像を開き、RGBA形式に変換して透過情報を保持
99
+ img = Image.open(image_path)
100
+ img = img.convert("RGBA")
101
+
102
+ # 透過部分を白色で塗りつぶすキャンバスを作成
103
+ canvas_image = Image.new('RGBA', img.size, (255, 255, 255, 255))
104
+ # 画像をキャンバスにペーストし、透過部分が白色になるように設定
105
+ canvas_image.paste(img, (0, 0), img)
106
+
107
+ # RGBAからRGBに変換し、透過部分を白色にする
108
+ image_pil = canvas_image.convert("RGB")
109
+ image_preprocessed = preprocess_image(image_pil)
110
+ image_preprocessed = np.expand_dims(image_preprocessed, axis=0)
111
+
112
+ # 推論を実行
113
+ prob = ort_session.run(None, {input_name: image_preprocessed})[0][0]
114
+ # タグを生成
115
+ combined_tags = []
116
+ general_tag_text = ""
117
+ character_tag_text = ""
118
+ remove_underscore = True
119
+ caption_separator = ", "
120
+ general_threshold = 0.35
121
+ character_threshold = 0.35
122
+
123
+ for i, p in enumerate(prob[4:]):
124
+ if i < len(general_tags) and p >= general_threshold:
125
+ tag_name = general_tags[i]
126
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
127
+ tag_name = tag_name.replace("_", " ")
128
+
129
+ if tag_name not in undesired_tags:
130
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
131
+ general_tag_text += caption_separator + tag_name
132
+ combined_tags.append(tag_name)
133
+ elif i >= len(general_tags) and p >= character_threshold:
134
+ tag_name = character_tags[i - len(general_tags)]
135
+ if remove_underscore and len(tag_name) > 3:
136
+ tag_name = tag_name.replace("_", " ")
137
+
138
+ if tag_name not in undesired_tags:
139
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
140
+ character_tag_text += caption_separator + tag_name
141
+ combined_tags.append(tag_name)
142
+
143
+ # 先頭のカンマを取る
144
+ if len(general_tag_text) > 0:
145
+ general_tag_text = general_tag_text[len(caption_separator) :]
146
+ if len(character_tag_text) > 0:
147
+ character_tag_text = character_tag_text[len(caption_separator) :]
148
+ tag_text = caption_separator.join(combined_tags)
149
+ return tag_text
utils/utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import requests
4
+ from tqdm import tqdm
5
+ import shutil
6
+
7
+ from PIL import Image, ImageOps
8
+ import numpy as np
9
+ import cv2
10
+
11
+ def load_cn_model(model_dir):
12
+ folder = model_dir
13
+ file_name = 'diffusion_pytorch_model.safetensors'
14
+ url = "https://huggingface.co/kataragi/ControlNet-LineartXL/resolve/main/Katarag_lineartXL-fp16.safetensors"
15
+
16
+ file_path = os.path.join(folder, file_name)
17
+ if not os.path.exists(file_path):
18
+ response = requests.get(url, stream=True)
19
+
20
+ total_size = int(response.headers.get('content-length', 0))
21
+ with open(file_path, 'wb') as f, tqdm(
22
+ desc=file_name,
23
+ total=total_size,
24
+ unit='iB',
25
+ unit_scale=True,
26
+ unit_divisor=1024,
27
+ ) as bar:
28
+ for data in response.iter_content(chunk_size=1024):
29
+ size = f.write(data)
30
+ bar.update(size)
31
+
32
+ def load_cn_config(model_dir):
33
+ folder = model_dir
34
+ file_name = 'config.json'
35
+ file_path = os.path.join(folder, file_name)
36
+ if not os.path.exists(file_path):
37
+ config_path = os.path.join(os.getcwd(), file_name)
38
+ shutil.copy(config_path, file_path)
39
+
40
+
41
+
42
+ def resize_image_aspect_ratio(image):
43
+ # 元の画像サイズを取得
44
+ original_width, original_height = image.size
45
+
46
+ # アスペクト比を計算
47
+ aspect_ratio = original_width / original_height
48
+
49
+ # 標準のアスペクト比サイズを定義
50
+ sizes = {
51
+ 1: (1024, 1024), # 正方形
52
+ 4/3: (1152, 896), # 横長画像
53
+ 3/2: (1216, 832),
54
+ 16/9: (1344, 768),
55
+ 21/9: (1568, 672),
56
+ 3/1: (1728, 576),
57
+ 1/4: (512, 2048), # 縦長画像
58
+ 1/3: (576, 1728),
59
+ 9/16: (768, 1344),
60
+ 2/3: (832, 1216),
61
+ 3/4: (896, 1152)
62
+ }
63
+
64
+ # 最も近いアスペクト比を見つける
65
+ closest_aspect_ratio = min(sizes.keys(), key=lambda x: abs(x - aspect_ratio))
66
+ target_width, target_height = sizes[closest_aspect_ratio]
67
+
68
+ # リサイズ処理
69
+ resized_image = image.resize((target_width, target_height), Image.ANTIALIAS)
70
+
71
+ return resized_image
72
+
73
+
74
+ def base_generation(size, color):
75
+ canvas = Image.new("RGBA", size, color)
76
+ return canvas