File size: 8,058 Bytes
b06793d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import json
import os
import sys
from dataclasses import dataclass, field
from glob import glob
from typing import Mapping

from PIL import Image
from tqdm import tqdm

from laion_face_common import generate_annotation


@dataclass
class RunProgress:
    """Checkpointable bookkeeping for a face-annotation run.

    Every field is a list of image paths. ``pending`` holds paths not yet processed;
    each processed path is appended to exactly one of the other lists.
    """
    pending: list = field(default_factory=list)  # Images still to be processed.
    success: list = field(default_factory=list)  # Images written to the dataset.
    skipped_size: list = field(default_factory=list)  # Shorter side too small or longer side too large.
    skipped_nsfw: list = field(default_factory=list)  # Metadata NSFW marker present and not "unlikely".
    skipped_noface: list = field(default_factory=list)  # Detector reported no faces at all.
    skipped_smallface: list = field(default_factory=list)  # Faces found, but none survived the size filter.
    skipped_error: list = field(default_factory=list)  # Image file could not be opened/decoded.


def _save_status(status: RunProgress, status_filename: str) -> None:
    """Write ``status`` to ``status_filename`` as the JSON resume checkpoint."""
    with open(status_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)


def _load_status(status_filename: str) -> RunProgress:
    """Rebuild a RunProgress from a checkpoint; fields missing from older checkpoints start empty."""
    with open(status_filename, 'rt') as fin:
        saved = json.load(fin)
    status = RunProgress()
    for key in status.__dict__:
        setattr(status, key, saved.get(key, []))
    return status


def _load_image_metadata(json_path: str) -> dict:
    """Return the sidecar metadata dict for an image, or {} when it is absent, unreadable, or not an object."""
    if not os.path.exists(json_path):
        return {}
    try:
        with open(json_path, 'rt', encoding='utf-8') as fin:
            metadata = json.load(fin)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError.
        print(f"Could not read metadata {json_path}: {e}")
        return {}
    return metadata if isinstance(metadata, dict) else {}


def _is_nsfw(image_metadata: dict) -> bool:
    """True when the metadata has a non-None "NSFW" marker other than "unlikely" (case-insensitive)."""
    if "NSFW" not in image_metadata:
        return False
    marker = image_metadata["NSFW"]  # This can be "", None, or other weird things.
    # str() so non-string markers (bools, numbers) are compared instead of crashing on .lower().
    return marker is not None and str(marker).lower() != "unlikely"


def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: Mapping = None,  # If present, maps a filename to a text prompt.
        max_faces: int = 5,
):
    """Build a ControlNet face dataset from the images matched by ``input_glob``.

    For each image that passes the NSFW, size and face filters, the ``empty`` image returned
    by ``generate_annotation`` is saved to ``output_directory`` (and the ``annotated`` image
    to ``annotated_output_directory`` when given). One JSONL line
    ``{"source", "target", "prompt"}`` is then appended to ``prompt_filename``.

    Progress is checkpointed to ``status_filename`` every 100 images and at the end. If that
    file already exists, the run resumes from it and appends to ``prompt_filename``. Images
    processed after the last checkpoint are redone on resume, so an interrupted run may leave
    duplicate lines in ``prompt_filename``.

    Args:
        status_filename: JSON checkpoint path (read on resume, rewritten periodically).
        prompt_filename: JSONL label file; truncated on a fresh run, appended to on resume.
        input_glob: Glob of target images; only expanded on a fresh run.
        output_directory: Directory for the generated source images (created if missing).
        annotated_output_directory: Optional directory for annotated images; "" disables them.
        min_image_size: Images whose shorter side is below this (px) are skipped.
        max_image_size: Images whose longer side exceeds this (px) are skipped.
        min_face_size_pixels: Forwarded to ``generate_annotation``.
        prompt_mapping: Basename -> prompt, used when the metadata has no caption.
        max_faces: Forwarded to ``generate_annotation`` (default matches the former hard-coded 5).
    """
    if os.path.exists(status_filename):
        print("Continuing from checkpoint.")
        status = _load_status(status_filename)
        prompt_mode = 'at'  # Keep the labels written before the interruption.
    else:
        print("Starting run.")
        status = RunProgress(pending=list(glob(input_glob)))
        prompt_mode = 'wt'
        _save_status(status, status_filename)

    print(f"{len(status.pending)} images remaining")

    # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
    # We will try on a per-image basis to see if there's a metadata .json.
    if prompt_mapping is None:
        prompt_mapping = {}

    # Create output directories up front so the first save does not fail.
    if output_directory:
        os.makedirs(output_directory, exist_ok=True)
    if annotated_output_directory:
        os.makedirs(annotated_output_directory, exist_ok=True)

    step = 0
    with open(prompt_filename, prompt_mode) as pout, tqdm(total=len(status.pending)) as pbar:
        while status.pending:
            step += 1
            if step % 100 == 0:
                # Checkpoint BEFORE popping. A save taken after the pop would list the current
                # image in no list at all, and it would be silently dropped on resume.
                _save_status(status, status_filename)

            full_filename = status.pending.pop()
            pbar.update(1)

            # Outputs are named by basename only; inputs sharing a basename overwrite each other.
            fname = os.path.basename(full_filename)
            output_filename = os.path.join(output_directory, fname)
            annotation_filename = (
                os.path.join(annotated_output_directory, fname) if annotated_output_directory else ""
            )

            # The LAION dataset has accompanying .json files with each image.
            image_metadata = _load_image_metadata(os.path.splitext(full_filename)[0] + ".json")
            if _is_nsfw(image_metadata):
                status.skipped_nsfw.append(full_filename)
                continue

            # Prefer the metadata caption; fall back to the mapping when it is missing, null or empty.
            image_prompt = image_metadata.get("caption") or prompt_mapping.get(fname, "")

            # Image.open is lazy, so the size check runs before the pixels are decoded.
            try:
                with Image.open(full_filename) as raw:
                    size_ok = min(raw.size) >= min_image_size and max(raw.size) <= max_image_size
                    img = raw.convert("RGB") if size_ok else None
            except (OSError, Image.DecompressionBombError) as e:
                # Corrupt/unreadable files are recorded instead of aborting (and re-aborting on resume).
                print(f"Could not load {full_filename}: {e}")
                status.skipped_error.append(full_filename)
                continue
            if img is None:
                status.skipped_size.append(full_filename)
                continue

            # We re-initialize the detector every time because it has a habit of triggering weird race conditions.
            empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
                img,
                max_faces=max_faces,
                min_face_size_pixels=min_face_size_pixels,
                return_annotation_data=True,
            )
            if faces_before_filtering == 0:
                # Skip images with no faces.
                status.skipped_noface.append(full_filename)
                continue
            if faces_after_filtering == 0:
                # Skip images with no faces large enough.
                status.skipped_smallface.append(full_filename)
                continue

            Image.fromarray(empty).save(output_filename)
            if annotation_filename:
                Image.fromarray(annotated).save(annotation_filename)

            # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
            # prompt.json
            # a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
            # a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
            # a target/xxxxx.jpg for each of the outputs.
            pout.write(json.dumps({
                "source": output_filename,
                "target": full_filename,
                "prompt": image_prompt,
            }) + "\n")
            pout.flush()  # Keep the label file in step with the saved images if the run is killed.
            status.success.append(full_filename)

    # We do save every 100 iterations, but it's good to save on completion, too.
    _save_status(status, status_filename)

    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")
    print(f"{len(status.skipped_error)} images rejected because they could not be read.")


if __name__ == "__main__":
    help_requested = "-h" in sys.argv or "--help" in sys.argv
    # Three positional arguments are required (prompt file, input glob, output dir), so argv
    # needs at least four entries including the script name; ">= 3" let argv[3] raise IndexError.
    if len(sys.argv) >= 4 and not help_requested:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
        output_dir = sys.argv[3]  # Should probably be a directory called "source".
        annotation_dir = sys.argv[4] if len(sys.argv) > 4 else ""  # Optional fourth argument.
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
        python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
        source and target are slightly confusing in this context.  We are writing the image names to prompt.jsonl, so 
        the naming system has to be consistent with what ControlNet expects.  In ControlNet, the source is the input and
        target is the output.  We are generating source images from targets in this application, so the second argument 
        should be a folder full of images.  The third argument should be 'source', where the images should be placed.
        Optionally, an 'annotated' directory can be provided.  Annotated images will be placed here.

        A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is 
        run.  If a run is cancelled, it can be resumed from this checkpoint.

        If invoking the script from bash, do not forget to enclose globs with quotes.  Example usage:
        `python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
        """)