danyloylo committed on
Commit b06793d
1 Parent(s): 2cbe6be

Upload 15 files

config.json ADDED
@@ -0,0 +1,42 @@
+ {
+   "_class_name": "ControlNetModel",
+   "_diffusers_version": "0.15.0.dev0",
+   "_name_or_path": "/home/josephcatrambone/ControlNet/models",
+   "act_fn": "silu",
+   "attention_head_dim": 8,
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "class_embed_type": null,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "controlnet_conditioning_channel_order": "rgb",
+   "cross_attention_dim": 768,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_time_scale_shift": "default",
+   "upcast_attention": null,
+   "use_linear_projection": false
+ }
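
The config above is a standard diffusers ControlNetModel description with SD 1.5 geometry (cross_attention_dim of 768, four latent input channels). A minimal loading sketch follows, assuming the diffusers ControlNetModel and StableDiffusionControlNetPipeline APIs from roughly the 0.15 era noted in _diffusers_version; the local path, base model id, conditioning image, and prompt are illustrative rather than part of this commit.

# Hedged sketch: load the uploaded diffusers-format ControlNet and drive an SD 1.5 pipeline with it.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "./",  # repo root holding config.json + diffusion_pytorch_model.bin (illustrative local checkout)
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed SD 1.5 base, matching cross_attention_dim == 768
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

face_pose = Image.open("face_pose.png")  # illustrative: an annotation produced by laion_face_common below
result = pipe("a portrait photo", image=face_pose, num_inference_steps=20).images[0]
result.save("output.png")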
control_v2p_sd15_mediapipe_face.full.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2a71953d7372d5585899b44693a7532ebbf80c091108ae2b8987ca93cc2dac2
+ size 8601300183
control_v2p_sd15_mediapipe_face.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f2ccead3a8c0b9fbf9cad7b8eaa29834983ced916c766a92fb84db34ff29e43
+ size 1445239863
control_v2p_sd15_mediapipe_face.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5be501156709895f0b14a7ec76faae7cf0a105f76895252a2c69db541629628f
+ size 1445154814
control_v2p_sd15_mediapipe_face.yaml ADDED
@@ -0,0 +1,79 @@
+ model:
+   target: cldm.cldm.ControlLDM
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "jpg"
+     cond_stage_key: "txt"
+     control_key: "hint"
+     image_size: 64
+     channels: 4
+     cond_stage_trainable: false
+     conditioning_key: crossattn
+     monitor: val/loss_simple_ema
+     scale_factor: 0.18215
+     use_ema: False
+     only_mid_control: False
+
+     control_stage_config:
+       target: cldm.cldm.ControlNet
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         hint_channels: 3
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     unet_config:
+       target: cldm.cldm.ControlledUnetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
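
This YAML is the model description consumed by cldm.model.create_model in the ControlNet codebase, and it pairs with the .full.ckpt uploaded above. A minimal sketch of that loading path, mirroring the top of gradio_face2image.py below but pointed at the sd15 filenames in this commit (the gradio script itself references sd21 paths, so treat these filenames as an assumption):

# Hedged sketch: build the ControlLDM described by this YAML and load the full checkpoint.
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

model = create_model('./control_v2p_sd15_mediapipe_face.yaml').cpu()
model.load_state_dict(load_state_dict('./control_v2p_sd15_mediapipe_face.full.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)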
diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f63de389f776b75bb11f10487a187573aea84f9a51debd08f314bd084e7fb362
+ size 1445254969
diffusion_pytorch_model.fp16.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c37b3dd41e956160909129b50f84fd938116550727b491192cbdbe6f896cd7b
+ size 722696633
diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fb50465b4fd7e15f0dc7df8031767e57309cfda2917082485bcf6c11bedb540
+ size 722598642
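
The .bin/.safetensors pairs above follow the diffusers "fp16" variant naming convention, so the half-precision weights should be selectable at load time. A small sketch, assuming the variant and use_safetensors arguments of from_pretrained are available in the diffusers release in use:

# Hedged sketch: select diffusion_pytorch_model.fp16.safetensors instead of the full-precision .bin.
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "./",  # repo root, as in the earlier sketch
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)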
gradio_face2image.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ from typing import Mapping
+
+ import gradio as gr
+ import numpy
+ import torch
+ import random
+ from PIL import Image
+
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+ from laion_face_common import generate_annotation
+ from share import *
+
+
+ model = create_model('./control_v2p_sd21_mediapipe_face.yaml').cpu()
+ model.load_state_dict(load_state_dict('./control_v2p_sd21_mediapipe_face.full.ckpt', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)  # ControlNet _only_ works with DDIM.
+
+
+ def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         empty = generate_annotation(input_image, max_faces)
+         visualization = Image.fromarray(empty)  # Save to help debug.
+
+         empty = numpy.moveaxis(empty, 2, 0)  # h, w, c -> c, h, w
+         control = torch.from_numpy(empty.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         # control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         # Sanity check the dimensions.
+         B, C, H, W = control.shape
+         assert C == 3
+         assert B == num_samples
+
+         if seed != -1:
+             random.seed(seed)
+             os.environ['PYTHONHASHSEED'] = str(seed)
+             numpy.random.seed(seed)
+             torch.manual_seed(seed)
+             torch.cuda.manual_seed(seed)
+             torch.backends.cudnn.deterministic = True
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(
+             ddim_steps,
+             num_samples,
+             shape,
+             cond,
+             verbose=False,
+             eta=eta,
+             unconditional_guidance_scale=scale,
+             unconditional_conditioning=un_cond
+         )
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8)
+         x_samples = numpy.moveaxis((x_samples * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8), 1, -1)  # b, c, h, w -> b, h, w, c
+         results = [visualization] + [x_samples[i] for i in range(num_samples)]
+
+     return results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with a Facial Pose")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 max_faces = gr.Slider(label="Max Faces", minimum=1, maximum=5, value=1, step=1)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, max_faces, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
laion_face_common.py ADDED
@@ -0,0 +1,180 @@
+ from typing import Mapping
+
+ import mediapipe as mp
+ import numpy
+ from PIL import Image
+
+
+ mp_drawing = mp.solutions.drawing_utils
+ mp_drawing_styles = mp.solutions.drawing_styles
+ mp_face_detection = mp.solutions.face_detection  # Only for counting faces.
+ mp_face_mesh = mp.solutions.face_mesh
+ mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
+ mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
+ mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
+
+ f_thick = 2
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+     face_connection_spec[edge] = head_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+     face_connection_spec[edge] = left_eye_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+     face_connection_spec[edge] = left_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ #     face_connection_spec[edge] = left_iris_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+     face_connection_spec[edge] = right_eye_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+     face_connection_spec[edge] = right_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ #     face_connection_spec[edge] = right_iris_draw
+ for edge in mp_face_mesh.FACEMESH_LIPS:
+     face_connection_spec[edge] = mouth_draw
+ iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+
+
+ def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
+     """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+     landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+     if len(image.shape) != 3:
+         raise ValueError("Input image must be H,W,C.")
+     image_rows, image_cols, image_channels = image.shape
+     if image_channels != 3:  # BGR channels
+         raise ValueError('Input image must contain three channel bgr data.')
+     for idx, landmark in enumerate(landmark_list.landmark):
+         if (
+                 (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+                 (landmark.HasField('presence') and landmark.presence < 0.5)
+         ):
+             continue
+         if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+             continue
+         image_x = int(image_cols*landmark.x)
+         image_y = int(image_rows*landmark.y)
+         draw_color = None
+         if isinstance(drawing_spec, Mapping):
+             if drawing_spec.get(idx) is None:
+                 continue
+             else:
+                 draw_color = drawing_spec[idx].color
+         elif isinstance(drawing_spec, DrawingSpec):
+             draw_color = drawing_spec.color
+         image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
+
+
+ def reverse_channels(image):
+     """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
+     # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
+     # im[:,:,::[2,1,0]] would also work but makes a copy of the data.
+     return image[:, :, ::-1]
+
+
+ def generate_annotation(
+         input_image: Image.Image,
+         max_faces: int,
+         min_face_size_pixels: int = 0,
+         return_annotation_data: bool = False
+ ):
+     """
+     Find up to 'max_faces' inside the provided input image.
+     If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
+     pixels in the image.
+     If return_annotation_data is TRUE (default: false) then in addition to returning the 'detected face' image, three
+     additional parameters will be returned: faces before filtering, faces after filtering, and an annotation image.
+     The faces_before_filtering return value is the number of faces detected in an image with no filtering.
+     faces_after_filtering is the number of faces remaining after filtering small faces.
+
+     :return:
+     If 'return_annotation_data==True', returns (numpy array, numpy array, int, int).
+     If 'return_annotation_data==False' (default), returns a numpy array.
+     """
+     with mp_face_mesh.FaceMesh(
+             static_image_mode=True,
+             max_num_faces=max_faces,
+             refine_landmarks=True,
+             min_detection_confidence=0.5,
+     ) as facemesh:
+         img_rgb = numpy.asarray(input_image)
+         results = facemesh.process(img_rgb).multi_face_landmarks
+
+         faces_found_before_filtering = len(results)
+
+         # Filter faces that are too small
+         filtered_landmarks = []
+         for lm in results:
+             landmarks = lm.landmark
+             face_rect = [
+                 landmarks[0].x,
+                 landmarks[0].y,
+                 landmarks[0].x,
+                 landmarks[0].y,
+             ]  # Left, up, right, down.
+             for i in range(len(landmarks)):
+                 face_rect[0] = min(face_rect[0], landmarks[i].x)
+                 face_rect[1] = min(face_rect[1], landmarks[i].y)
+                 face_rect[2] = max(face_rect[2], landmarks[i].x)
+                 face_rect[3] = max(face_rect[3], landmarks[i].y)
+             if min_face_size_pixels > 0:
+                 face_width = abs(face_rect[2] - face_rect[0])
+                 face_height = abs(face_rect[3] - face_rect[1])
+                 face_width_pixels = face_width * input_image.size[0]
+                 face_height_pixels = face_height * input_image.size[1]
+                 face_size = min(face_width_pixels, face_height_pixels)
+                 if face_size >= min_face_size_pixels:
+                     filtered_landmarks.append(lm)
+             else:
+                 filtered_landmarks.append(lm)
+
+         faces_remaining_after_filtering = len(filtered_landmarks)
+
+         # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
+         empty = numpy.zeros_like(img_rgb)
+
+         # Draw detected faces:
+         for face_landmarks in filtered_landmarks:
+             mp_drawing.draw_landmarks(
+                 empty,
+                 face_landmarks,
+                 connections=face_connection_spec.keys(),
+                 landmark_drawing_spec=None,
+                 connection_drawing_spec=face_connection_spec
+             )
+             draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
+
+         # Flip BGR back to RGB.
+         empty = reverse_channels(empty)
+
+         # We might have to generate a composite.
+         if return_annotation_data:
+             # Note that we're copying the input image AND flipping the channels so we can draw on top of it.
+             annotated = reverse_channels(numpy.asarray(input_image)).copy()
+             for face_landmarks in filtered_landmarks:
+                 mp_drawing.draw_landmarks(
+                     empty,
+                     face_landmarks,
+                     connections=face_connection_spec.keys(),
+                     landmark_drawing_spec=None,
+                     connection_drawing_spec=face_connection_spec
+                 )
+                 draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
+             annotated = reverse_channels(annotated)
+
+         if not return_annotation_data:
+             return empty
+         else:
+             return empty, annotated, faces_found_before_filtering, faces_remaining_after_filtering
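
A short usage sketch for generate_annotation, following its docstring; the input filename is illustrative:

# Hedged sketch: annotate a single portrait with the face-mesh drawing specs defined above.
from PIL import Image
from laion_face_common import generate_annotation

img = Image.open("portrait.jpg").convert("RGB")  # illustrative input

# Default call: returns only the H,W,3 uint8 annotation (mesh lines on a black background).
annotation = generate_annotation(img, max_faces=2)
Image.fromarray(annotation).save("face_pose.png")

# With return_annotation_data=True: also a composite over the input plus the before/after face counts.
annotation, composite, faces_before, faces_after = generate_annotation(
    img, max_faces=2, min_face_size_pixels=64, return_annotation_data=True
)
print(f"{faces_before} faces detected, {faces_after} kept after the size filter.")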
laion_face_dataset.py ADDED
@@ -0,0 +1,55 @@
+ import json
+ import numpy
+ import os
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ class LaionDataset(Dataset):
+     def __init__(self):
+         self.data = []
+         with open('./training/laion-face-processed/prompt.jsonl', 'rt') as f:
+             for line in f:
+                 self.data.append(json.loads(line))
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         item = self.data[idx]
+
+         source_filename = os.path.split(item['source'])[-1]
+         target_filename = os.path.split(item['target'])[-1]
+         prompt = item['prompt']
+
+         # If prompt is "" or null, make it something simple.
+         if not prompt:
+             print(f"Image with index {idx} / {source_filename} has no text.")
+             prompt = "an image"
+
+         source_image = Image.open('./training/laion-face-processed/source/' + source_filename).convert("RGB")
+         target_image = Image.open('./training/laion-face-processed/target/' + target_filename).convert("RGB")
+         # Resize the image so that the minimum edge is bigger than 512x512, then crop center.
+         # This may cut off some parts of the face image, but in general they're smaller than 512x512 and we still want
+         # to cover the literal edge cases.
+         img_size = source_image.size
+         scale_factor = 512/min(img_size)
+         source_image = source_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
+         target_image = target_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
+         img_size = source_image.size
+         left_padding = (img_size[0] - 512)//2
+         top_padding = (img_size[1] - 512)//2
+         source_image = source_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
+         target_image = target_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
+
+         source = numpy.asarray(source_image)
+         target = numpy.asarray(target_image)
+
+         # Normalize source images to [0, 1].
+         source = source.astype(numpy.float32) / 255.0
+
+         # Normalize target images to [-1, 1].
+         target = (target.astype(numpy.float32) / 127.5) - 1.0
+
+         return dict(jpg=target, txt=prompt, hint=source)
+
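
A quick sketch of what one sample from LaionDataset looks like, per the normalization comments above; it assumes the ./training/laion-face-processed layout the class expects is already populated:

# Hedged sketch: inspect a single training sample.
from laion_face_dataset import LaionDataset

ds = LaionDataset()
sample = ds[0]
# sample["hint"]: float32 source annotation in [0, 1], 512x512x3 after the resize and center crop.
# sample["jpg"]:  float32 target photo in [-1, 1], same shape.
# sample["txt"]:  the caption string ("an image" when the prompt was empty).
print(sample["hint"].shape, sample["jpg"].shape, sample["txt"])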
tool_download_face_targets.py ADDED
@@ -0,0 +1,86 @@
+ #!/usr/bin/python3
+ """
+ tool_download_face_targets.py
+
+ Reads in the metadata from the LAION images and begins downloading all images.
+ """
+
+ import json
+ import os
+ import sys
+ import time
+ import urllib
+ import urllib.request
+ try:
+     from tqdm import tqdm
+ except ImportError:
+     # Wrap this method into the identity.
+     print("TQDM not found. Progress will be quiet without 'verbose'.")
+     def tqdm(x):
+         return x
+
+
+ def main(logfile_path: str, verbose: bool = False, pause_between_fetches: float = 0.0):
+     """Open the metadata.json file from the training directory and fetch all target images."""
+     # Toggle a function pointer so we don't have to check verbosity everywhere.
+     def out(x):
+         pass
+     if verbose:
+         out = print
+
+     log = open(logfile_path, 'at')
+     skipped_image_count = 0
+     errored_image_count = 0
+     successful_image_count = 0
+     if not os.path.exists("training"):
+         print("ERROR: training directory does not exist in the current directory.")
+         print("Has the archive been unzipped?")
+         print("Are you running from the project root?")
+         return 2  # BASH: No such directory.
+     if not os.path.exists("training/laion-face-processed/metadata.json"):
+         print("ERROR: metadata.json was not found in training/laion-face-processed.")
+         return 2
+     with open("training/laion-face-processed/metadata.json", 'rt') as md_in:
+         metadata = json.load(md_in)
+     # Create the directory for targets if it does not exist.
+     if not os.path.exists("training/laion-face-processed/target"):
+         os.mkdir("training/laion-face-processed/target")
+     for image_id, image_data in tqdm(metadata.items()):
+         filename = f"training/laion-face-processed/target/{image_id}.jpg"
+         if os.path.exists(filename):
+             out(f"Skipping {image_id}: file exists.")
+             skipped_image_count += 1
+             continue
+         if not download_file(image_data['url'], filename, verbose):
+             error_message = f"Problem downloading {image_id}"
+             out(error_message)
+             log.write(error_message + "\n")
+             log.flush()  # Flush often in case we crash.
+             errored_image_count += 1
+         if pause_between_fetches > 0.0:
+             time.sleep(pause_between_fetches)
+         successful_image_count += 1
+     log.close()
+     print("Run success.")
+     print(f"{skipped_image_count} images skipped")
+     print(f"{errored_image_count} images failed to download")
+     print(f"{successful_image_count} images downloaded")
+
+
+ def download_file(url: str, output_path: str, verbose: bool = False) -> bool:
+     """Download the file with the given URL and save it to the specified path. Return true on success."""
+     try:
+         r = urllib.request.urlopen(url)
+         if not r.status == 200:
+             return False
+         with open(output_path, 'wb') as fout:
+             fout.write(r.read())
+         return True
+     except Exception as e:
+         if verbose:
+             print(e)
+         return False
+
+
+ if __name__ == "__main__":
+     main("downloads.log", verbose="-v" in sys.argv)
tool_generate_face_poses.py ADDED
@@ -0,0 +1,180 @@
+ import json
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from glob import glob
+ from typing import Mapping
+
+ from PIL import Image
+ from tqdm import tqdm
+
+ from laion_face_common import generate_annotation
+
+
+ @dataclass
+ class RunProgress:
+     pending: list = field(default_factory=list)
+     success: list = field(default_factory=list)
+     skipped_size: list = field(default_factory=list)
+     skipped_nsfw: list = field(default_factory=list)
+     skipped_noface: list = field(default_factory=list)
+     skipped_smallface: list = field(default_factory=list)
+
+
+ def main(
+         status_filename: str,
+         prompt_filename: str,
+         input_glob: str,
+         output_directory: str,
+         annotated_output_directory: str = "",
+         min_image_size: int = 384,
+         max_image_size: int = 32766,
+         min_face_size_pixels: int = 64,
+         prompt_mapping: dict = None,  # If present, maps a filename to a text prompt.
+ ):
+     status = RunProgress()
+
+     if os.path.exists(status_filename):
+         print("Continuing from checkpoint.")
+         # Restore a saved state:
+         status_temp = json.load(open(status_filename, 'rt'))
+         for k in status.__dict__.keys():
+             status.__setattr__(k, status_temp[k])
+         # Output label file:
+         pout = open(prompt_filename, 'at')
+     else:
+         print("Starting run.")
+         status = RunProgress()
+         status.pending = list(glob(input_glob))
+         # Output label file:
+         pout = open(prompt_filename, 'wt')
+         with open(status_filename, 'wt') as fout:
+             json.dump(status.__dict__, fout)
+
+     print(f"{len(status.pending)} images remaining")
+
+     # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
+     # We will try on a per-image basis to see if there's a metadata .json.
+     if prompt_mapping is None:
+         prompt_mapping = dict()
+
+     step = 0
+     with tqdm(total=len(status.pending)) as pbar:
+         while len(status.pending) > 0:
+             full_filename = status.pending.pop()
+             pbar.update(1)
+             step += 1
+
+             if step % 100 == 0:
+                 # Checkpoint save:
+                 with open(status_filename, 'wt') as fout:
+                     json.dump(status.__dict__, fout)
+
+             _fpath, fname = os.path.split(full_filename)
+
+             # Make our output filenames.
+             # We used to do this here so we could check if a file existed before writing, then skip it, but since we
+             # have a 'status' that we cache and update, we no longer have to do this check.
+             annotation_filename = ""
+             if annotated_output_directory:
+                 annotation_filename = os.path.join(annotated_output_directory, fname)
+             output_filename = os.path.join(output_directory, fname)
+
+             # The LAION dataset has accompanying .json files with each image.
+             partial_filename, extension = os.path.splitext(full_filename)
+             candidate_json_fullpath = partial_filename + ".json"
+             image_metadata = {}
+             if os.path.exists(candidate_json_fullpath):
+                 try:
+                     image_metadata = json.load(open(candidate_json_fullpath, 'rt'))
+                 except Exception as e:
+                     print(e)
+             if "NSFW" in image_metadata:
+                 nsfw_marker = image_metadata.get("NSFW")  # This can be "", None, or other weird things.
+                 if nsfw_marker is not None and nsfw_marker.lower() != "unlikely":
+                     # Skip NSFW images.
+                     status.skipped_nsfw.append(full_filename)
+                     continue
+
+             # Try to get a prompt/caption from the metadata or the prompt mapping.
+             image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))
+
+             # Load image:
+             img = Image.open(full_filename).convert("RGB")
+             img_width = img.size[0]
+             img_height = img.size[1]
+             img_size = min(img.size[0], img.size[1])
+             if img_size < min_image_size or max(img_width, img_height) > max_image_size:
+                 status.skipped_size.append(full_filename)
+                 continue
+
+             # We re-initialize the detector every time because it has a habit of triggering weird race conditions.
+             empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
+                 img,
+                 max_faces=5,
+                 min_face_size_pixels=min_face_size_pixels,
+                 return_annotation_data=True
+             )
+             if faces_before_filtering == 0:
+                 # Skip images with no faces.
+                 status.skipped_noface.append(full_filename)
+                 continue
+             if faces_after_filtering == 0:
+                 # Skip images with no faces large enough
+                 status.skipped_smallface.append(full_filename)
+                 continue
+
+             Image.fromarray(empty).save(output_filename)
+             if annotation_filename:
+                 Image.fromarray(annotated).save(annotation_filename)
+
+             # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
+             # prompt.json
+             # a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
+             # a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
+             # a target/xxxxx.jpg for each of the outputs.
+             pout.write(json.dumps({
+                 "source": os.path.join(output_directory, fname),
+                 "target": full_filename,
+                 "prompt": image_prompt,
+             }) + "\n")
+             pout.flush()
+             status.success.append(full_filename)
+
+     # We do save every 100 iterations, but it's good to save on completion, too.
+     with open(status_filename, 'wt') as fout:
+         json.dump(status.__dict__, fout)
+
+     pout.close()
+     print("Done!")
+     print(f"{len(status.success)} images added to dataset.")
+     print(f"{len(status.skipped_size)} images rejected for size.")
+     print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
+     print(f"{len(status.skipped_noface)} images rejected for not having faces.")
+     print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")
+
+
+ if __name__ == "__main__":
+     if len(sys.argv) >= 3 and "-h" not in sys.argv:
+         prompt_jsonl = sys.argv[1]
+         in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
+         output_dir = sys.argv[3]  # Should probably be a directory called "source".
+         annotation_dir = ""
+         if len(sys.argv) > 4:
+             annotation_dir = sys.argv[4]
+         main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
+     else:
+         print(f"""Usage:
+         python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
+         source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
+         the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
+         target is the output. We are generating source images from targets in this application, so the second argument
+         should be a folder full of images. The third argument should be 'source', where the images should be places.
+         Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.
+
+         A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
+         run. If a run is cancelled, it can be resumed from this checkpoint.
+
+         If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
+         `python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
+         """)
train_laion_face.py ADDED
@@ -0,0 +1,46 @@
+ from share import *
+
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from laion_face_dataset import LaionDataset
+ from cldm.logger import ImageLogger
+ from cldm.model import create_model, load_state_dict
+
+
+ # Configs
+ resume_path = './models/controlnet_sd21_laion_face.ckpt'
+ batch_size = 4
+ logger_freq = 2500
+ learning_rate = 1e-5
+ sd_locked = True
+ only_mid_control = False
+
+
+ # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
+ model = create_model('./models/cldm_v21.yaml').cpu()
+ model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+ model.learning_rate = learning_rate
+ model.sd_locked = sd_locked
+ model.only_mid_control = only_mid_control
+
+
+ # Save every so often:
+ ckpt_callback = pl.callbacks.ModelCheckpoint(
+     dirpath="./checkpoints/",
+     filename="ckpt_controlnet_sd21_{epoch}_{step}_{loss}",
+     monitor='train/loss_simple_step',
+     save_top_k=5,
+     every_n_train_steps=5000,
+     save_last=True,
+ )
+
+
+ # Misc
+ dataset = LaionDataset()
+ dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
+ logger = ImageLogger(batch_frequency=logger_freq)
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])
+
+
+ # Train!
+ trainer.fit(model, dataloader)
train_laion_face_sd15.py ADDED
@@ -0,0 +1,42 @@
+ from share import *
+
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from laion_face_dataset import LaionDataset
+ from cldm.logger import ImageLogger
+ from cldm.model import create_model, load_state_dict
+
+
+ # Configs
+ resume_path = './models/controlnet_sd15_laion_face.ckpt'
+ batch_size = 8
+ logger_freq = 2500
+ learning_rate = 1e-5
+ sd_locked = True
+ only_mid_control = False
+
+ # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
+ model = create_model('./models/cldm_v15.yaml').cpu()
+ model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+ model.learning_rate = learning_rate
+ model.sd_locked = sd_locked
+ model.only_mid_control = only_mid_control
+
+ # Save every so often:
+ ckpt_callback = pl.callbacks.ModelCheckpoint(
+     dirpath="./checkpoints/",
+     filename="controlnet_sd15_laion_face_{epoch}_{step}_{loss}.ckpt",
+     monitor='train/loss_simple_step',
+     save_top_k=5,
+     every_n_train_steps=5000,
+     save_last=True,
+ )
+
+ # Misc
+ dataset = LaionDataset()
+ dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
+ logger = ImageLogger(batch_frequency=logger_freq)
+ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger, ckpt_callback])
+
+ # Train!
+ trainer.fit(model, dataloader)