diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..31bd305a8f1063d428f9b24eecea8e70c7d5f97d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +cog_sdxl/tests/assets/ filter=lfs diff=lfs merge=lfs -text diff --git a/cog_sdxl/.dockerignore b/cog_sdxl/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..0b70e8abb2ab49861db5d4f035c31d628152fbb8 --- /dev/null +++ b/cog_sdxl/.dockerignore @@ -0,0 +1,35 @@ +sdxl-cache/ +refiner-cache/ +safety-cache/ +trained-model/ +*.png +cache/ +checkpoint/ +training_out/ +dreambooth/ +lora/ +ttemp/ +.git/ +cog_class_data/ +dataset/ +training_data/ +temp/ +temp_in/ +cog_instance_data/ +example_datasets/ +trained_model.tar +zeke_data.tar +data.tar +zeke.zip +sketch-mountains-input.jpeg +training_out* +weights +inference_* +trained-model +*.zip +tmp/ +blip-cache/ +clipseg-cache/ +swin2sr-cache/ +weights-cache/ +tests/ \ No newline at end of file diff --git a/cog_sdxl/.gitignore b/cog_sdxl/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9a958903ea661a77948595223b0a836ff3eceb5a --- /dev/null +++ b/cog_sdxl/.gitignore @@ -0,0 +1,23 @@ + +refiner-cache +sdxl-cache +safety-cache +trained-model +temp +temp_in +cache +.cog +__pycache__ +wandb +ft* +*.ipynb +dataset +training_data +training_out +output* +training_out* +trained_model.tar +checkpoint* +weights +__*.zip +**-cache \ No newline at end of file diff --git a/cog_sdxl/LICENSE b/cog_sdxl/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b9afc029c74e9ca9ddd699ecf4f19eb6007420ed --- /dev/null +++ b/cog_sdxl/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023, Replicate, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cog_sdxl/README.md b/cog_sdxl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cbb757d84a9a638c98bcb341e4a3df7a382d0398 --- /dev/null +++ b/cog_sdxl/README.md @@ -0,0 +1,41 @@ +# Cog-SDXL + +[![Replicate demo and cloud API](https://replicate.com/stability-ai/sdxl/badge)](https://replicate.com/stability-ai/sdxl) + +This is an implementation of Stability AI's [SDXL](https://github.com/Stability-AI/generative-models) as a [Cog](https://github.com/replicate/cog) model. + +## Development + +Follow the [model pushing guide](https://replicate.com/docs/guides/push-a-model) to push your own fork of SDXL to [Replicate](https://replicate.com). + +## Basic Usage + +for prediction, + +```bash +cog predict -i prompt="a photo of TOK" +``` + +```bash +cog train -i input_images=@example_datasets/__data.zip -i use_face_detection_instead=True +``` + +```bash +cog run -p 5000 python -m cog.server.http +``` + +## Update notes + +**2023-08-17** +* ROI problem is fixed. 
+* Now BLIP caption_prefix does not interfere with BLIP captioner. + + +**2023-08-12** +* Input types are inferred from input name extensions, or from the `input_images_filetype` argument +* Preprocssing are now done with fp16, and if no mask is found, the model will use the whole image + +**2023-08-11** +* Default to 768x768 resolution training +* Rank as argument now, default to 32 +* Now uses Swin2SR `caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr` as default, and will upscale + downscale to 768x768 diff --git a/cog_sdxl/cog.yaml b/cog_sdxl/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f981c2f215d4c2400f2fdb4ec8e119911f1cb814 --- /dev/null +++ b/cog_sdxl/cog.yaml @@ -0,0 +1,33 @@ +# Configuration for Cog ⚙️ +# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md + +build: + gpu: true + cuda: "11.8" + python_version: "3.9" + system_packages: + - "libgl1-mesa-glx" + - "ffmpeg" + - "libsm6" + - "libxext6" + - "wget" + python_packages: + - "diffusers<=0.25" + - "torch==2.0.1" + - "transformers==4.31.0" + - "invisible-watermark==0.2.0" + - "accelerate==0.21.0" + - "pandas==2.0.3" + - "torchvision==0.15.2" + - "numpy==1.25.1" + - "pandas==2.0.3" + - "fire==0.5.0" + - "opencv-python>=4.1.0.25" + - "mediapipe==0.10.2" + + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" && chmod +x /usr/local/bin/pget + - wget http://thegiflibrary.tumblr.com/post/11565547760 -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task + +predict: "predict.py:Predictor" +train: "train.py:train" diff --git a/cog_sdxl/dataset_and_utils.py b/cog_sdxl/dataset_and_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d82a9cfc2cb9f02fccb5cae3971645b6e33d853b --- /dev/null +++ b/cog_sdxl/dataset_and_utils.py @@ -0,0 +1,421 @@ +import os +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import PIL +import torch +import torch.utils.checkpoint +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from PIL import Image +from safetensors import safe_open +from safetensors.torch import save_file +from torch.utils.data import Dataset +from transformers import AutoTokenizer, PretrainedConfig + + +def prepare_image( + pil_image: PIL.Image.Image, w: int = 512, h: int = 512 +) -> torch.Tensor: + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +def prepare_mask( + pil_image: PIL.Image.Image, w: int = 512, h: int = 512 +) -> torch.Tensor: + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("L")) + arr = arr.astype(np.float32) / 255.0 + arr = np.expand_dims(arr, 0) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class PreprocessedDataset(Dataset): + def __init__( + self, + csv_path: str, + tokenizer_1, + tokenizer_2, + vae_encoder, + text_encoder_1=None, + text_encoder_2=None, + do_cache: bool = False, + size: int = 512, + text_dropout: float = 0.0, + scale_vae_latents: bool = True, + substitute_caption_map: Dict[str, str] = {}, + ): + super().__init__() + + self.data = pd.read_csv(csv_path) + self.csv_path = csv_path + + self.caption = 
self.data["caption"] + # make it lowercase + self.caption = self.caption.str.lower() + for key, value in substitute_caption_map.items(): + self.caption = self.caption.str.replace(key.lower(), value) + + self.image_path = self.data["image_path"] + + if "mask_path" not in self.data.columns: + self.mask_path = None + else: + self.mask_path = self.data["mask_path"] + + if text_encoder_1 is None: + self.return_text_embeddings = False + else: + self.text_encoder_1 = text_encoder_1 + self.text_encoder_2 = text_encoder_2 + self.return_text_embeddings = True + assert ( + NotImplementedError + ), "Preprocessing Text Encoder is not implemented yet" + + self.tokenizer_1 = tokenizer_1 + self.tokenizer_2 = tokenizer_2 + + self.vae_encoder = vae_encoder + self.scale_vae_latents = scale_vae_latents + self.text_dropout = text_dropout + + self.size = size + + if do_cache: + self.vae_latents = [] + self.tokens_tuple = [] + self.masks = [] + + self.do_cache = True + + print("Captions to train on: ") + for idx in range(len(self.data)): + token, vae_latent, mask = self._process(idx) + self.vae_latents.append(vae_latent) + self.tokens_tuple.append(token) + self.masks.append(mask) + + del self.vae_encoder + + else: + self.do_cache = False + + @torch.no_grad() + def _process( + self, idx: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + image_path = self.image_path[idx] + image_path = os.path.join(os.path.dirname(self.csv_path), image_path) + + image = PIL.Image.open(image_path).convert("RGB") + image = prepare_image(image, self.size, self.size).to( + dtype=self.vae_encoder.dtype, device=self.vae_encoder.device + ) + + caption = self.caption[idx] + + print(caption) + + # tokenizer_1 + ti1 = self.tokenizer_1( + caption, + padding="max_length", + max_length=77, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).input_ids + + ti2 = self.tokenizer_2( + caption, + padding="max_length", + max_length=77, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).input_ids + + vae_latent = self.vae_encoder.encode(image).latent_dist.sample() + + if self.scale_vae_latents: + vae_latent = vae_latent * self.vae_encoder.config.scaling_factor + + if self.mask_path is None: + mask = torch.ones_like( + vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device + ) + + else: + mask_path = self.mask_path[idx] + mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path) + + mask = PIL.Image.open(mask_path) + mask = prepare_mask(mask, self.size, self.size).to( + dtype=self.vae_encoder.dtype, device=self.vae_encoder.device + ) + + mask = torch.nn.functional.interpolate( + mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest" + ) + mask = mask.repeat(1, vae_latent.shape[1], 1, 1) + + assert len(mask.shape) == 4 and len(vae_latent.shape) == 4 + + return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze() + + def __len__(self) -> int: + return len(self.data) + + def atidx( + self, idx: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + if self.do_cache: + return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx] + else: + return self._process(idx) + + def __getitem__( + self, idx: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + token, vae_latent, mask = self.atidx(idx) + return token, vae_latent, mask + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = 
"text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def load_models(pretrained_model_name_or_path, revision, device, weight_dtype): + tokenizer_one = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + revision=revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=revision, + use_fast=False, + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + pretrained_model_name_or_path, revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + pretrained_model_name_or_path, revision, subfolder="text_encoder_2" + ) + text_encoder_one = text_encoder_cls_one.from_pretrained( + pretrained_model_name_or_path, subfolder="text_encoder", revision=revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision + ) + + vae = AutoencoderKL.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", revision=revision + ) + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", revision=revision + ) + + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + unet.to(device, dtype=weight_dtype) + vae.to(device, dtype=torch.float32) + text_encoder_one.to(device, dtype=weight_dtype) + text_encoder_two.to(device, dtype=weight_dtype) + + return ( + tokenizer_one, + tokenizer_two, + noise_scheduler, + text_encoder_one, + text_encoder_two, + vae, + unet, + ) + + +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + """ + Returns: + a state dict containing just the attention processor parameters. + """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[ + f"{attn_processor_key}.{parameter_key}" + ] = parameter + + return attn_processors_state_dict + + +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance( + inserting_toks, list + ), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." 
+ + # Register the placeholder tokens as additional special tokens and grow the embedding table; the new rows are randomly re-initialized below, scaled to the existing embedding std. + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + + std_token_embedding = ( + text_encoder.text_model.embeddings.token_embedding.weight.data.std() + ) + + print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + + text_encoder.text_model.embeddings.token_embedding.weight.data[ + self.train_ids + ] = ( + torch.randn( + len(self.train_ids), text_encoder.text_model.config.hidden_size + ) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + self.embeddings_settings[ + f"original_embeddings_{idx}" + ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + inu = torch.ones((len(tokenizer),), dtype=torch.bool) + inu[self.train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = inu + + print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert ( + self.train_ids is not None + ), "Initialize new tokens before saving embeddings." + tensors = {} + for idx, text_encoder in enumerate(self.text_encoders): + assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[ + 0 + ] == len(self.tokenizers[0]), "Tokenizers should be the same." + new_token_embeddings = ( + text_encoder.text_model.embeddings.token_embedding.weight.data[ + self.train_ids + ] + ) + tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder): + # Assuming new tokens are of the format <s_i> + self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])] + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + assert self.train_ids is not None, "New tokens could not be converted to IDs."
+ text_encoder.text_model.embeddings.token_embedding.weight.data[ + self.train_ids + ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype) + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + text_encoder.text_model.embeddings.token_embedding.weight.data[ + index_no_updates + ] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = ( + text_encoder.text_model.embeddings.token_embedding.weight.data[ + index_updates + ] + ) + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + text_encoder.text_model.embeddings.token_embedding.weight.data[ + index_updates + ] = new_embeddings + + def load_embeddings(self, file_path: str): + with safe_open(file_path, framework="pt", device=self.device.type) as f: + for idx in range(len(self.text_encoders)): + text_encoder = self.text_encoders[idx] + tokenizer = self.tokenizers[idx] + + loaded_embeddings = f.get_tensor(f"text_encoders_{idx}") + self._load_embeddings(loaded_embeddings, tokenizer, text_encoder) diff --git a/cog_sdxl/example_datasets/README.md b/cog_sdxl/example_datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b206063e7178fe5da8496b02b44ccd9a49e03311 --- /dev/null +++ b/cog_sdxl/example_datasets/README.md @@ -0,0 +1,3 @@ +## Example Datasets + +This folder contains three example datasets that were used to tune SDXL using the Replicate API, along with (at the top level) example outputs generated from those datasets. 
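+
+For example, assuming Cog is installed and you are running from the repository root, a local fine-tune on one of these datasets can be started with the same training entrypoint shown in the top-level README (the `zeke.zip` archive ships in this folder; the extra flags documented there, such as `use_face_detection_instead`, apply here as well):
+
+```bash
+# Fine-tune SDXL locally on the bundled zeke example dataset
+cog train -i input_images=@example_datasets/zeke.zip
+```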
\ No newline at end of file diff --git a/cog_sdxl/example_datasets/kiriko.png b/cog_sdxl/example_datasets/kiriko.png new file mode 100644 index 0000000000000000000000000000000000000000..7e28b85a30f2b2d0bfa4a7d50457fb07f4b9a16e --- /dev/null +++ b/cog_sdxl/example_datasets/kiriko.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d9861dc28bf9fd0b33992f927630f1ade740017158be76f0afa385008b0775a +size 1140952 diff --git a/cog_sdxl/example_datasets/kiriko/0.src.jpg b/cog_sdxl/example_datasets/kiriko/0.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f0b6a345f183c1ba9ea64cffa23043becfa542d2 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/0.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/1.src.jpg b/cog_sdxl/example_datasets/kiriko/1.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3819fef2086f5405256fe7594c52a53b631fcc8a Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/1.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/10.src.jpg b/cog_sdxl/example_datasets/kiriko/10.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9150513ba47d2ca225f7be87b79c3af26d6bc364 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/10.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/11.src.jpg b/cog_sdxl/example_datasets/kiriko/11.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9c84e3d94b10a79c4bb7d87b0f7bd2738c35df1 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/11.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/12.src.jpg b/cog_sdxl/example_datasets/kiriko/12.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cdb4b4f711c5d3d2d280fe2f76b85b29c3750258 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/12.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/2.src.jpg b/cog_sdxl/example_datasets/kiriko/2.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6276de2478875f03dc6fd347ae28994e03e5536 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/2.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/3.src.jpg b/cog_sdxl/example_datasets/kiriko/3.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77e7f6300185ca2a3871ba2f51836bf80abae1fd Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/3.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/4.src.jpg b/cog_sdxl/example_datasets/kiriko/4.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..58b472f18f034cb1685e39fb760bf31d9e55d1b7 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/4.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/5.src.jpg b/cog_sdxl/example_datasets/kiriko/5.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3e20cbf0ada99238515f84c68bd94ba80fad020 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/5.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/6.src.jpg b/cog_sdxl/example_datasets/kiriko/6.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d7de50b72aee5fa76ab96af016bd35c3cb31618 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/6.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/7.src.jpg b/cog_sdxl/example_datasets/kiriko/7.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03f8461de36760f10289ef7c06526937e5d316a9 
Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/7.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/8.src.jpg b/cog_sdxl/example_datasets/kiriko/8.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee549daa38613b477e2f3dc7034752989f3080d2 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/8.src.jpg differ diff --git a/cog_sdxl/example_datasets/kiriko/9.src.jpg b/cog_sdxl/example_datasets/kiriko/9.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..041c2839184dbf6607e84a06bb37115104dbcfd4 Binary files /dev/null and b/cog_sdxl/example_datasets/kiriko/9.src.jpg differ diff --git a/cog_sdxl/example_datasets/monster.png b/cog_sdxl/example_datasets/monster.png new file mode 100644 index 0000000000000000000000000000000000000000..f9382f6a1beff6ce1c66167e9298677e332418d6 --- /dev/null +++ b/cog_sdxl/example_datasets/monster.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40ed43418f51b843dbf3f4d2ef2d3a833840c9c362868711ae580f835c5015bc +size 515313 diff --git a/cog_sdxl/example_datasets/monster/caption.csv b/cog_sdxl/example_datasets/monster/caption.csv new file mode 100644 index 0000000000000000000000000000000000000000..435ca7f69a3489c81d72241e1db525eb2f15f8c1 --- /dev/null +++ b/cog_sdxl/example_datasets/monster/caption.csv @@ -0,0 +1,6 @@ +caption,image_file +a TOK on a windowsill,monstertoy (1).jpg +a photo of smiling TOK in an office,monstertoy (2).jpg +a photo of TOK sitting by a window,monstertoy (3).jpg +a photo of TOK on a car,monstertoy (4).jpg +a photo of TOK smiling on the ground,monstertoy (5).jpg \ No newline at end of file diff --git a/cog_sdxl/example_datasets/monster/monstertoy (1).jpg b/cog_sdxl/example_datasets/monster/monstertoy (1).jpg new file mode 100644 index 0000000000000000000000000000000000000000..b4dfad832bbb3885156ee9282b2caf72b4f4cdbb Binary files /dev/null and b/cog_sdxl/example_datasets/monster/monstertoy (1).jpg differ diff --git a/cog_sdxl/example_datasets/monster/monstertoy (2).jpg b/cog_sdxl/example_datasets/monster/monstertoy (2).jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c82fd929b4b8e17fb9e3eaa660d4cf09f5e3b6a Binary files /dev/null and b/cog_sdxl/example_datasets/monster/monstertoy (2).jpg differ diff --git a/cog_sdxl/example_datasets/monster/monstertoy (3).jpg b/cog_sdxl/example_datasets/monster/monstertoy (3).jpg new file mode 100644 index 0000000000000000000000000000000000000000..109aec5ee6f0b6705b0ac45c78514ffa22464576 Binary files /dev/null and b/cog_sdxl/example_datasets/monster/monstertoy (3).jpg differ diff --git a/cog_sdxl/example_datasets/monster/monstertoy (4).jpg b/cog_sdxl/example_datasets/monster/monstertoy (4).jpg new file mode 100644 index 0000000000000000000000000000000000000000..64b9ef50224719896259497fe3eba693f759a854 Binary files /dev/null and b/cog_sdxl/example_datasets/monster/monstertoy (4).jpg differ diff --git a/cog_sdxl/example_datasets/monster/monstertoy (5).jpg b/cog_sdxl/example_datasets/monster/monstertoy (5).jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5194b5d306edbbb59fcdc82dcba15036dfe2beb Binary files /dev/null and b/cog_sdxl/example_datasets/monster/monstertoy (5).jpg differ diff --git a/cog_sdxl/example_datasets/monster_uni.png b/cog_sdxl/example_datasets/monster_uni.png new file mode 100644 index 0000000000000000000000000000000000000000..e60ea4ad3974fc3231130fe586a3129e4365c6a1 --- /dev/null +++ b/cog_sdxl/example_datasets/monster_uni.png @@ -0,0 +1,3 
@@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98bf9d0cbef77d7cc5a541940a32a02a9ea49d8122f9722401c9b3c7956aa47a +size 1713313 diff --git a/cog_sdxl/example_datasets/zeke.zip b/cog_sdxl/example_datasets/zeke.zip new file mode 100644 index 0000000000000000000000000000000000000000..e969aa96ade6b6ffe2660158725af203fd57c5ff --- /dev/null +++ b/cog_sdxl/example_datasets/zeke.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64d655ee118eec386272a15c8e3c2522bc40155cd0f39f451596f7800df403e6 +size 860587 diff --git a/cog_sdxl/example_datasets/zeke/0.src.jpg b/cog_sdxl/example_datasets/zeke/0.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c4e13507fa28d3902836900ec6002fc9e7ffd724 Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/0.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke/1.src.jpg b/cog_sdxl/example_datasets/zeke/1.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2a8ec5aea9e1f7d82328f6f554fb6d0773528d68 Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/1.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke/2.src.jpg b/cog_sdxl/example_datasets/zeke/2.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2d746c2a7f9cfcb5dcea55dd99a205db462814d3 Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/2.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke/3.src.jpg b/cog_sdxl/example_datasets/zeke/3.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cac55d0ae58a50435b756c2e196ab31c573d5ed3 Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/3.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke/4.src.jpg b/cog_sdxl/example_datasets/zeke/4.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d122efaacf19766801fb24b9029aaa7bd9c344be Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/4.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke/5.src.jpg b/cog_sdxl/example_datasets/zeke/5.src.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ca60cfb728af3e2101dc08d03a640d91a82aa41 Binary files /dev/null and b/cog_sdxl/example_datasets/zeke/5.src.jpg differ diff --git a/cog_sdxl/example_datasets/zeke_unicorn.png b/cog_sdxl/example_datasets/zeke_unicorn.png new file mode 100644 index 0000000000000000000000000000000000000000..2a1c9b81c53877753993a3dd4983161f41988b1f --- /dev/null +++ b/cog_sdxl/example_datasets/zeke_unicorn.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59339a736d96dde6f8459ac1f357ed63707e5f5eb50fea3616a64eaaf2586416 +size 1647360 diff --git a/cog_sdxl/feature-extractor/preprocessor_config.json b/cog_sdxl/feature-extractor/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c3bb8850ded75a46d33226fabf32fa8490c2bd3 --- /dev/null +++ b/cog_sdxl/feature-extractor/preprocessor_config.json @@ -0,0 +1,20 @@ +{ + "crop_size": 224, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_resize": true, + "feature_extractor_type": "CLIPFeatureExtractor", + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "resample": 3, + "size": 224 + } \ No newline at end of file diff --git a/cog_sdxl/no_init.py b/cog_sdxl/no_init.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5b004783662911f21abcdd16601eb00c103a64 --- /dev/null +++ b/cog_sdxl/no_init.py @@ 
-0,0 +1,121 @@ +import contextlib +import contextvars +import threading +from typing import ( + Callable, + ContextManager, + NamedTuple, + Optional, + TypeVar, + Union, +) + +import torch + +__all__ = ["no_init_or_tensor"] + + +Model = TypeVar("Model") + + +def no_init_or_tensor( + loading_code: Optional[Callable[..., Model]] = None +) -> Union[Model, ContextManager]: + """ + Suppress the initialization of weights while loading a model. + + Can either directly be passed a callable containing model-loading code, + which will be evaluated with weight initialization suppressed, + or used as a context manager around arbitrary model-loading code. + + Args: + loading_code: Either a callable to evaluate + with model weight initialization suppressed, + or None (the default) to use as a context manager. + + Returns: + The return value of `loading_code`, if `loading_code` is callable. + + Otherwise, if `loading_code` is None, returns a context manager + to be used in a `with`-statement. + + Examples: + As a context manager:: + + from transformers import AutoConfig, AutoModelForCausalLM + config = AutoConfig("EleutherAI/gpt-j-6B") + with no_init_or_tensor(): + model = AutoModelForCausalLM.from_config(config) + + Or, directly passing a callable:: + + from transformers import AutoConfig, AutoModelForCausalLM + config = AutoConfig("EleutherAI/gpt-j-6B") + model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config)) + """ + if loading_code is None: + return _NoInitOrTensorImpl.context_manager() + elif callable(loading_code): + with _NoInitOrTensorImpl.context_manager(): + return loading_code() + else: + raise TypeError( + "no_init_or_tensor() expected a callable to evaluate," + " or None if being used as a context manager;" + f' got an object of type "{type(loading_code).__name__}" instead.' + ) + + +class _NoInitOrTensorImpl: + # Implementation of the thread-safe, async-safe, re-entrant context manager + # version of no_init_or_tensor(). + # This class essentially acts as a namespace. + # It is not instantiable, because modifications to torch functions + # inherently affect the global scope, and thus there is no worthwhile data + # to store in the class instance scope. + _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) + _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) + _ORIGINAL_EMPTY = torch.empty + + is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False) + _count_active: int = 0 + _count_active_lock = threading.Lock() + + @classmethod + @contextlib.contextmanager + def context_manager(cls): + if cls.is_active.get(): + yield + return + + with cls._count_active_lock: + cls._count_active += 1 + if cls._count_active == 1: + for mod in cls._MODULES: + mod.reset_parameters = cls._disable(mod.reset_parameters) + # When torch.empty is called, make it map to meta device by replacing + # the device in kwargs. 
+ torch.empty = cls._ORIGINAL_EMPTY + reset_token = cls.is_active.set(True) + + try: + yield + finally: + cls.is_active.reset(reset_token) + with cls._count_active_lock: + cls._count_active -= 1 + if cls._count_active == 0: + torch.empty = cls._ORIGINAL_EMPTY + for mod, original in cls._MODULE_ORIGINALS: + mod.reset_parameters = original + + @staticmethod + def _disable(func): + def wrapper(*args, **kwargs): + # Behaves as normal except in an active context + if not _NoInitOrTensorImpl.is_active.get(): + return func(*args, **kwargs) + + return wrapper + + __init__ = None diff --git a/cog_sdxl/predict.py b/cog_sdxl/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..8268092729b1e00cca455da05ca6f8f4221b15dd --- /dev/null +++ b/cog_sdxl/predict.py @@ -0,0 +1,462 @@ +import hashlib +import json +import os +import shutil +import subprocess +import time +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from weights import WeightsDownloadCache + +import numpy as np +import torch +from cog import BasePredictor, Input, Path +from diffusers import ( + DDIMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + PNDMScheduler, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, +) +from diffusers.models.attention_processor import LoRAAttnProcessor2_0 +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from diffusers.utils import load_image +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPImageProcessor + +from dataset_and_utils import TokenEmbeddingsHandler + +SDXL_MODEL_CACHE = "./sdxl-cache" +REFINER_MODEL_CACHE = "./refiner-cache" +SAFETY_CACHE = "./safety-cache" +FEATURE_EXTRACTOR = "./feature-extractor" +SDXL_URL = "https://weights.replicate.delivery/default/sdxl/sdxl-vae-upcast-fix.tar" +REFINER_URL = ( + "https://weights.replicate.delivery/default/sdxl/refiner-no-vae-no-encoder-1.0.tar" +) +SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar" + + +class KarrasDPM: + def from_config(config): + return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True) + + +SCHEDULERS = { + "DDIM": DDIMScheduler, + "DPMSolverMultistep": DPMSolverMultistepScheduler, + "HeunDiscrete": HeunDiscreteScheduler, + "KarrasDPM": KarrasDPM, + "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler, + "K_EULER": EulerDiscreteScheduler, + "PNDM": PNDMScheduler, +} + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def load_trained_weights(self, weights, pipe): + from no_init import no_init_or_tensor + + # weights can be a URLPath, which behaves in unexpected ways + weights = str(weights) + if self.tuned_weights == weights: + print("skipping loading .. weights already loaded") + return + + # predictions can be cancelled while in this function, which + # interrupts this finishing. 
To protect against odd states we + # set tuned_weights to a value that lets the next prediction + # know if it should try to load weights or if loading completed + self.tuned_weights = 'loading' + + local_weights_cache = self.weights_cache.ensure(weights) + + # load UNET + print("Loading fine-tuned model") + self.is_lora = False + + maybe_unet_path = os.path.join(local_weights_cache, "unet.safetensors") + if not os.path.exists(maybe_unet_path): + print("Does not have Unet. assume we are using LoRA") + self.is_lora = True + + if not self.is_lora: + print("Loading Unet") + + new_unet_params = load_file( + os.path.join(local_weights_cache, "unet.safetensors") + ) + # this should return _IncompatibleKeys(missing_keys=[...], unexpected_keys=[]) + pipe.unet.load_state_dict(new_unet_params, strict=False) + + else: + print("Loading Unet LoRA") + + unet = pipe.unet + + tensors = load_file(os.path.join(local_weights_cache, "lora.safetensors")) + + unet_lora_attn_procs = {} + name_rank_map = {} + for tk, tv in tensors.items(): + # up is N, d + tensors[tk] = tv.half() + if tk.endswith("up.weight"): + proc_name = ".".join(tk.split(".")[:-3]) + r = tv.shape[1] + name_rank_map[proc_name] = r + + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[ + block_id + ] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + with no_init_or_tensor(): + module = LoRAAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=name_rank_map[name], + ).half() + unet_lora_attn_procs[name] = module.to("cuda", non_blocking=True) + + unet.set_attn_processor(unet_lora_attn_procs) + unet.load_state_dict(tensors, strict=False) + + # load text + handler = TokenEmbeddingsHandler( + [pipe.text_encoder, pipe.text_encoder_2], [pipe.tokenizer, pipe.tokenizer_2] + ) + handler.load_embeddings(os.path.join(local_weights_cache, "embeddings.pti")) + + # load params + with open(os.path.join(local_weights_cache, "special_params.json"), "r") as f: + params = json.load(f) + + self.token_map = params + self.tuned_weights = weights + self.tuned_model = True + + def unload_trained_weights(self, pipe: DiffusionPipeline): + print("unloading loras") + + def _recursive_unset_lora(module: torch.nn.Module): + if hasattr(module, "lora_layer"): + module.lora_layer = None + + for _, child in module.named_children(): + _recursive_unset_lora(child) + + _recursive_unset_lora(pipe.unet) + self.tuned_weights = None + self.tuned_model = False + + def setup(self, weights: Optional[Path] = None): + """Load the model into memory to make running multiple predictions efficient""" + + start = time.time() + self.tuned_model = False + self.tuned_weights = None + if str(weights) == "weights": + weights = None + + self.weights_cache = WeightsDownloadCache() + + print("Loading safety checker...") + if not os.path.exists(SAFETY_CACHE): + download_weights(SAFETY_URL, SAFETY_CACHE) + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + SAFETY_CACHE, torch_dtype=torch.float16 + ).to("cuda") + self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) + + if not 
os.path.exists(SDXL_MODEL_CACHE): + download_weights(SDXL_URL, SDXL_MODEL_CACHE) + + print("Loading sdxl txt2img pipeline...") + self.txt2img_pipe = DiffusionPipeline.from_pretrained( + SDXL_MODEL_CACHE, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ) + self.is_lora = False + if weights or os.path.exists("./trained-model"): + self.load_trained_weights(weights, self.txt2img_pipe) + + self.txt2img_pipe.to("cuda") + + print("Loading SDXL img2img pipeline...") + self.img2img_pipe = StableDiffusionXLImg2ImgPipeline( + vae=self.txt2img_pipe.vae, + text_encoder=self.txt2img_pipe.text_encoder, + text_encoder_2=self.txt2img_pipe.text_encoder_2, + tokenizer=self.txt2img_pipe.tokenizer, + tokenizer_2=self.txt2img_pipe.tokenizer_2, + unet=self.txt2img_pipe.unet, + scheduler=self.txt2img_pipe.scheduler, + ) + self.img2img_pipe.to("cuda") + + print("Loading SDXL inpaint pipeline...") + self.inpaint_pipe = StableDiffusionXLInpaintPipeline( + vae=self.txt2img_pipe.vae, + text_encoder=self.txt2img_pipe.text_encoder, + text_encoder_2=self.txt2img_pipe.text_encoder_2, + tokenizer=self.txt2img_pipe.tokenizer, + tokenizer_2=self.txt2img_pipe.tokenizer_2, + unet=self.txt2img_pipe.unet, + scheduler=self.txt2img_pipe.scheduler, + ) + self.inpaint_pipe.to("cuda") + + print("Loading SDXL refiner pipeline...") + # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always? + # - in the case of fine-tuned SDXL should we still? + # FIXME(ja): if the answer to above is use VAE/Text_Encoder_2 from fine-tune + # what does this imply about lora + refiner? does the refiner need to know about + + if not os.path.exists(REFINER_MODEL_CACHE): + download_weights(REFINER_URL, REFINER_MODEL_CACHE) + + print("Loading refiner pipeline...") + self.refiner = DiffusionPipeline.from_pretrained( + REFINER_MODEL_CACHE, + text_encoder_2=self.txt2img_pipe.text_encoder_2, + vae=self.txt2img_pipe.vae, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ) + self.refiner.to("cuda") + print("setup took: ", time.time() - start) + # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt + + def load_image(self, path): + shutil.copyfile(path, "/tmp/image.png") + return load_image("/tmp/image.png").convert("RGB") + + def run_safety_checker(self, image): + safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( + "cuda" + ) + np_image = [np.array(val) for val in image] + image, has_nsfw_concept = self.safety_checker( + images=np_image, + clip_input=safety_checker_input.pixel_values.to(torch.float16), + ) + return image, has_nsfw_concept + + @torch.inference_mode() + def predict( + self, + prompt: str = Input( + description="Input prompt", + default="An astronaut riding a rainbow unicorn", + ), + negative_prompt: str = Input( + description="Input Negative Prompt", + default="", + ), + image: Path = Input( + description="Input image for img2img or inpaint mode", + default=None, + ), + mask: Path = Input( + description="Input mask for inpaint mode. 
Black areas will be preserved, white areas will be inpainted.", + default=None, + ), + width: int = Input( + description="Width of output image", + default=1024, + ), + height: int = Input( + description="Height of output image", + default=1024, + ), + num_outputs: int = Input( + description="Number of images to output.", + ge=1, + le=4, + default=1, + ), + scheduler: str = Input( + description="scheduler", + choices=SCHEDULERS.keys(), + default="K_EULER", + ), + num_inference_steps: int = Input( + description="Number of denoising steps", ge=1, le=500, default=50 + ), + guidance_scale: float = Input( + description="Scale for classifier-free guidance", ge=1, le=50, default=7.5 + ), + prompt_strength: float = Input( + description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image", + ge=0.0, + le=1.0, + default=0.8, + ), + seed: int = Input( + description="Random seed. Leave blank to randomize the seed", default=None + ), + refine: str = Input( + description="Which refine style to use", + choices=["no_refiner", "expert_ensemble_refiner", "base_image_refiner"], + default="no_refiner", + ), + high_noise_frac: float = Input( + description="For expert_ensemble_refiner, the fraction of noise to use", + default=0.8, + le=1.0, + ge=0.0, + ), + refine_steps: int = Input( + description="For base_image_refiner, the number of steps to refine, defaults to num_inference_steps", + default=None, + ), + apply_watermark: bool = Input( + description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.", + default=True, + ), + lora_scale: float = Input( + description="LoRA additive scale. Only applicable on trained models.", + ge=0.0, + le=1.0, + default=0.6, + ), + replicate_weights: str = Input( + description="Replicate LoRA weights to use. Leave blank to use the default weights.", + default=None, + ), + disable_safety_checker: bool = Input( + description="Disable safety checker for generated images. This feature is only available through the API. 
See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)", + default=False, + ), + ) -> List[Path]: + """Run a single prediction on the model.""" + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + if replicate_weights: + self.load_trained_weights(replicate_weights, self.txt2img_pipe) + elif self.tuned_model: + self.unload_trained_weights(self.txt2img_pipe) + + # OOMs can leave vae in bad state + if self.txt2img_pipe.vae.dtype == torch.float32: + self.txt2img_pipe.vae.to(dtype=torch.float16) + + sdxl_kwargs = {} + if self.tuned_model: + # consistency with fine-tuning API + for k, v in self.token_map.items(): + prompt = prompt.replace(k, v) + print(f"Prompt: {prompt}") + if image and mask: + print("inpainting mode") + sdxl_kwargs["image"] = self.load_image(image) + sdxl_kwargs["mask_image"] = self.load_image(mask) + sdxl_kwargs["strength"] = prompt_strength + sdxl_kwargs["width"] = width + sdxl_kwargs["height"] = height + pipe = self.inpaint_pipe + elif image: + print("img2img mode") + sdxl_kwargs["image"] = self.load_image(image) + sdxl_kwargs["strength"] = prompt_strength + pipe = self.img2img_pipe + else: + print("txt2img mode") + sdxl_kwargs["width"] = width + sdxl_kwargs["height"] = height + pipe = self.txt2img_pipe + + if refine == "expert_ensemble_refiner": + sdxl_kwargs["output_type"] = "latent" + sdxl_kwargs["denoising_end"] = high_noise_frac + elif refine == "base_image_refiner": + sdxl_kwargs["output_type"] = "latent" + + if not apply_watermark: + # toggles watermark for this prediction + watermark_cache = pipe.watermark + pipe.watermark = None + self.refiner.watermark = None + + pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config) + generator = torch.Generator("cuda").manual_seed(seed) + + common_args = { + "prompt": [prompt] * num_outputs, + "negative_prompt": [negative_prompt] * num_outputs, + "guidance_scale": guidance_scale, + "generator": generator, + "num_inference_steps": num_inference_steps, + } + + if self.is_lora: + sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale} + + output = pipe(**common_args, **sdxl_kwargs) + + if refine in ["expert_ensemble_refiner", "base_image_refiner"]: + refiner_kwargs = { + "image": output.images, + } + + if refine == "expert_ensemble_refiner": + refiner_kwargs["denoising_start"] = high_noise_frac + if refine == "base_image_refiner" and refine_steps: + common_args["num_inference_steps"] = refine_steps + + output = self.refiner(**common_args, **refiner_kwargs) + + if not apply_watermark: + pipe.watermark = watermark_cache + self.refiner.watermark = watermark_cache + + if not disable_safety_checker: + _, has_nsfw_content = self.run_safety_checker(output.images) + + output_paths = [] + for i, image in enumerate(output.images): + if not disable_safety_checker: + if has_nsfw_content[i]: + print(f"NSFW content detected in image {i}") + continue + output_path = f"/tmp/out-{i}.png" + image.save(output_path) + output_paths.append(Path(output_path)) + + if len(output_paths) == 0: + raise Exception( + f"NSFW content detected. Try running it again, or try a different prompt." 
+ ) + + return output_paths diff --git a/cog_sdxl/preprocess.py b/cog_sdxl/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..8865cf2997f68224faa52f0e3a365ab6009ae39b --- /dev/null +++ b/cog_sdxl/preprocess.py @@ -0,0 +1,599 @@ +# Have SwinIR upsample +# Have BLIP auto caption +# Have CLIPSeg auto mask concept + +import gc +import fnmatch +import mimetypes +import os +import re +import shutil +import tarfile +from pathlib import Path +from typing import List, Literal, Optional, Tuple, Union +from zipfile import ZipFile + +import cv2 +import mediapipe as mp +import numpy as np +import pandas as pd +import torch +from PIL import Image, ImageFilter +from tqdm import tqdm +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPSegForImageSegmentation, + CLIPSegProcessor, + Swin2SRForImageSuperResolution, + Swin2SRImageProcessor, +) + +from predict import download_weights + +# model is fixed to Salesforce/blip-image-captioning-large +BLIP_URL = "https://weights.replicate.delivery/default/blip_large/blip_large.tar" +BLIP_PROCESSOR_URL = ( + "https://weights.replicate.delivery/default/blip_processor/blip_processor.tar" +) +BLIP_PATH = "./blip-cache" +BLIP_PROCESSOR_PATH = "./blip-proc-cache" + +# model is fixed to CIDAS/clipseg-rd64-refined +CLIPSEG_URL = "https://weights.replicate.delivery/default/clip_seg_rd64_refined/clip_seg_rd64_refined.tar" +CLIPSEG_PROCESSOR = "https://weights.replicate.delivery/default/clip_seg_processor/clip_seg_processor.tar" +CLIPSEG_PATH = "./clipseg-cache" +CLIPSEG_PROCESSOR_PATH = "./clipseg-proc-cache" + +# model is fixed to caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr +SWIN2SR_URL = "https://weights.replicate.delivery/default/swin2sr_realworld_sr_x4_64_bsrgan_psnr/swin2sr_realworld_sr_x4_64_bsrgan_psnr.tar" +SWIN2SR_PATH = "./swin2sr-cache" + +TEMP_OUT_DIR = "./temp/" +TEMP_IN_DIR = "./temp_in/" + +CSV_MATCH = "caption" + + +def preprocess( + input_images_filetype: str, + input_zip_path: Path, + caption_text: str, + mask_target_prompts: str, + target_size: int, + crop_based_on_salience: bool, + use_face_detection_instead: bool, + temp: float, + substitution_tokens: List[str], +) -> Path: + # assert str(files).endswith(".zip"), "files must be a zip file" + + # clear TEMP_IN_DIR first. 
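+    # (both TEMP_OUT_DIR and TEMP_IN_DIR are wiped and recreated below, so files
+    # left over from a previous run cannot leak into the new training set)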
+ + for path in [TEMP_OUT_DIR, TEMP_IN_DIR]: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path) + + caption_csv = None + + if input_images_filetype == "zip" or str(input_zip_path).endswith(".zip"): + with ZipFile(str(input_zip_path), "r") as zip_ref: + for zip_info in zip_ref.infolist(): + if zip_info.filename[-1] == "/" or zip_info.filename.startswith( + "__MACOSX" + ): + continue + mt = mimetypes.guess_type(zip_info.filename) + if mt and mt[0] and mt[0].startswith("image/"): + zip_info.filename = os.path.basename(zip_info.filename) + zip_ref.extract(zip_info, TEMP_IN_DIR) + if ( + mt + and mt[0] + and mt[0] == "text/csv" + and CSV_MATCH in zip_info.filename + ): + zip_info.filename = os.path.basename(zip_info.filename) + zip_ref.extract(zip_info, TEMP_IN_DIR) + caption_csv = os.path.join(TEMP_IN_DIR, zip_info.filename) + elif input_images_filetype == "tar" or str(input_zip_path).endswith(".tar"): + assert str(input_zip_path).endswith( + ".tar" + ), "files must be a tar file if not zip" + with tarfile.open(input_zip_path, "r") as tar_ref: + for tar_info in tar_ref: + if tar_info.name[-1] == "/" or tar_info.name.startswith("__MACOSX"): + continue + + mt = mimetypes.guess_type(tar_info.name) + if mt and mt[0] and mt[0].startswith("image/"): + tar_info.name = os.path.basename(tar_info.name) + tar_ref.extract(tar_info, TEMP_IN_DIR) + if mt and mt[0] and mt[0] == "text/csv" and CSV_MATCH in tar_info.name: + tar_info.name = os.path.basename(tar_info.name) + tar_ref.extract(tar_info, TEMP_IN_DIR) + caption_csv = os.path.join(TEMP_IN_DIR, tar_info.name) + else: + assert False, "input_images_filetype must be zip or tar" + + output_dir: str = TEMP_OUT_DIR + + load_and_save_masks_and_captions( + files=TEMP_IN_DIR, + output_dir=output_dir, + caption_text=caption_text, + caption_csv=caption_csv, + mask_target_prompts=mask_target_prompts, + target_size=target_size, + crop_based_on_salience=crop_based_on_salience, + use_face_detection_instead=use_face_detection_instead, + temp=temp, + substitution_tokens=substitution_tokens, + ) + + return Path(TEMP_OUT_DIR) + + +@torch.no_grad() +@torch.cuda.amp.autocast() +def swin_ir_sr( + images: List[Image.Image], + target_size: Optional[Tuple[int, int]] = None, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + **kwargs, +) -> List[Image.Image]: + """ + Upscales images using SwinIR. Returns a list of PIL images. + If the image is already larger than the target size, it will not be upscaled + and will be returned as is. 
+ + """ + if not os.path.exists(SWIN2SR_PATH): + download_weights(SWIN2SR_URL, SWIN2SR_PATH) + model = Swin2SRForImageSuperResolution.from_pretrained(SWIN2SR_PATH).to(device) + processor = Swin2SRImageProcessor() + + out_images = [] + + for image in tqdm(images): + ori_w, ori_h = image.size + if target_size is not None: + if ori_w >= target_size[0] and ori_h >= target_size[1]: + out_images.append(image) + continue + + inputs = processor(image, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = model(**inputs) + + output = ( + outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() + ) + output = np.moveaxis(output, source=0, destination=-1) + output = (output * 255.0).round().astype(np.uint8) + output = Image.fromarray(output) + + out_images.append(output) + + return out_images + + +@torch.no_grad() +@torch.cuda.amp.autocast() +def clipseg_mask_generator( + images: List[Image.Image], + target_prompts: Union[List[str], str], + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + bias: float = 0.01, + temp: float = 1.0, + **kwargs, +) -> List[Image.Image]: + """ + Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image + """ + + if isinstance(target_prompts, str): + print( + f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images' + ) + + target_prompts = [target_prompts] * len(images) + if not os.path.exists(CLIPSEG_PROCESSOR_PATH): + download_weights(CLIPSEG_PROCESSOR, CLIPSEG_PROCESSOR_PATH) + if not os.path.exists(CLIPSEG_PATH): + download_weights(CLIPSEG_URL, CLIPSEG_PATH) + processor = CLIPSegProcessor.from_pretrained(CLIPSEG_PROCESSOR_PATH) + model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_PATH).to(device) + + masks = [] + + for image, prompt in tqdm(zip(images, target_prompts)): + original_size = image.size + + inputs = processor( + text=[prompt, ""], + images=[image] * 2, + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(device) + + outputs = model(**inputs) + + logits = outputs.logits + probs = torch.nn.functional.softmax(logits / temp, dim=0)[0] + probs = (probs + bias).clamp_(0, 1) + probs = 255 * probs / probs.max() + + # make mask greyscale + mask = Image.fromarray(probs.cpu().numpy()).convert("L") + + # resize mask to original size + mask = mask.resize(original_size) + + masks.append(mask) + + return masks + + +@torch.no_grad() +def blip_captioning_dataset( + images: List[Image.Image], + text: Optional[str] = None, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + substitution_tokens: Optional[List[str]] = None, + **kwargs, +) -> List[str]: + """ + Returns a list of captions for the given images + """ + if not os.path.exists(BLIP_PROCESSOR_PATH): + download_weights(BLIP_PROCESSOR_URL, BLIP_PROCESSOR_PATH) + if not os.path.exists(BLIP_PATH): + download_weights(BLIP_URL, BLIP_PATH) + processor = BlipProcessor.from_pretrained(BLIP_PROCESSOR_PATH) + model = BlipForConditionalGeneration.from_pretrained(BLIP_PATH).to(device) + captions = [] + text = text.strip() + print(f"Input captioning text: {text}") + for image in tqdm(images): + inputs = processor(image, return_tensors="pt").to("cuda") + out = model.generate( + **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7 + ) + caption = processor.decode(out[0], skip_special_tokens=True) + + # BLIP 2 lowercases all caps tokens. This should properly replace them w/o messing up subwords. 
I'm sure there's a better way to do this. + for token in substitution_tokens: + print(token) + sub_cap = " " + caption + " " + print(sub_cap) + sub_cap = sub_cap.replace(" " + token.lower() + " ", " " + token + " ") + caption = sub_cap.strip() + + captions.append(text + " " + caption) + print("Generated captions", captions) + return captions + + +def face_mask_google_mediapipe( + images: List[Image.Image], blur_amount: float = 0.0, bias: float = 50.0 +) -> List[Image.Image]: + """ + Returns a list of images with masks on the face parts. + """ + mp_face_detection = mp.solutions.face_detection + mp_face_mesh = mp.solutions.face_mesh + + face_detection = mp_face_detection.FaceDetection( + model_selection=1, min_detection_confidence=0.1 + ) + face_mesh = mp_face_mesh.FaceMesh( + static_image_mode=True, max_num_faces=1, min_detection_confidence=0.1 + ) + + masks = [] + for image in tqdm(images): + image_np = np.array(image) + + # Perform face detection + results_detection = face_detection.process(image_np) + ih, iw, _ = image_np.shape + if results_detection.detections: + for detection in results_detection.detections: + bboxC = detection.location_data.relative_bounding_box + + bbox = ( + int(bboxC.xmin * iw), + int(bboxC.ymin * ih), + int(bboxC.width * iw), + int(bboxC.height * ih), + ) + + # make sure bbox is within image + bbox = ( + max(0, bbox[0]), + max(0, bbox[1]), + min(iw - bbox[0], bbox[2]), + min(ih - bbox[1], bbox[3]), + ) + + print(bbox) + + # Extract face landmarks + face_landmarks = face_mesh.process( + image_np[bbox[1] : bbox[1] + bbox[3], bbox[0] : bbox[0] + bbox[2]] + ).multi_face_landmarks + + # https://github.com/google/mediapipe/issues/1615 + # This was def helpful + indexes = [ + 10, + 338, + 297, + 332, + 284, + 251, + 389, + 356, + 454, + 323, + 361, + 288, + 397, + 365, + 379, + 378, + 400, + 377, + 152, + 148, + 176, + 149, + 150, + 136, + 172, + 58, + 132, + 93, + 234, + 127, + 162, + 21, + 54, + 103, + 67, + 109, + ] + + if face_landmarks: + mask = Image.new("L", (iw, ih), 0) + mask_np = np.array(mask) + + for face_landmark in face_landmarks: + face_landmark = [face_landmark.landmark[idx] for idx in indexes] + landmark_points = [ + (int(l.x * bbox[2]) + bbox[0], int(l.y * bbox[3]) + bbox[1]) + for l in face_landmark + ] + mask_np = cv2.fillPoly( + mask_np, [np.array(landmark_points)], 255 + ) + + mask = Image.fromarray(mask_np) + + # Apply blur to the mask + if blur_amount > 0: + mask = mask.filter(ImageFilter.GaussianBlur(blur_amount)) + + # Apply bias to the mask + if bias > 0: + mask = np.array(mask) + mask = mask + bias * np.ones(mask.shape, dtype=mask.dtype) + mask = np.clip(mask, 0, 255) + mask = Image.fromarray(mask) + + # Convert mask to 'L' mode (grayscale) before saving + mask = mask.convert("L") + + masks.append(mask) + else: + # If face landmarks are not available, add a black mask of the same size as the image + masks.append(Image.new("L", (iw, ih), 255)) + + else: + print("No face detected, adding full mask") + # If no face is detected, add a white mask of the same size as the image + masks.append(Image.new("L", (iw, ih), 255)) + + return masks + + +def _crop_to_square( + image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None +): + cx, cy = com + width, height = image.size + if width > height: + left_possible = max(cx - height / 2, 0) + left = min(left_possible, width - height) + right = left + height + top = 0 + bottom = height + else: + left = 0 + right = width + top_possible = max(cy - width / 2, 0) + top = min(top_possible, 
height - width) + bottom = top + width + + image = image.crop((left, top, right, bottom)) + + if resize_to: + image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS) + + return image + + +def _center_of_mass(mask: Image.Image): + """ + Returns the center of mass of the mask + """ + x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1])) + mask_np = np.array(mask) + 0.01 + x_ = x * mask_np + y_ = y * mask_np + + x = np.sum(x_) / np.sum(mask_np) + y = np.sum(y_) / np.sum(mask_np) + + return x, y + + +def load_and_save_masks_and_captions( + files: Union[str, List[str]], + output_dir: str = TEMP_OUT_DIR, + caption_text: Optional[str] = None, + caption_csv: Optional[str] = None, + mask_target_prompts: Optional[Union[List[str], str]] = None, + target_size: int = 1024, + crop_based_on_salience: bool = True, + use_face_detection_instead: bool = False, + temp: float = 1.0, + n_length: int = -1, + substitution_tokens: Optional[List[str]] = None, +): + """ + Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images + to output dir. If mask_target_prompts is given, it will generate kinda-segmentation-masks for the prompts and save them as well. + + Example: + >>> x = load_and_save_masks_and_captions( + files="./data/images", + output_dir="./data/masks_and_captions", + caption_text="a photo of", + mask_target_prompts="cat", + target_size=768, + crop_based_on_salience=True, + use_face_detection_instead=False, + temp=1.0, + n_length=-1, + ) + """ + os.makedirs(output_dir, exist_ok=True) + + # load images + if isinstance(files, str): + # check if it is a directory + if os.path.isdir(files): + # get all the .png .jpg in the directory + files = ( + _find_files("*.png", files) + + _find_files("*.jpg", files) + + _find_files("*.jpeg", files) + ) + + if len(files) == 0: + raise Exception( + f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg/jpeg files." + ) + if n_length == -1: + n_length = len(files) + files = sorted(files)[:n_length] + print("Image files: ", files) + images = [Image.open(file).convert("RGB") for file in files] + + # captions + if caption_csv: + print(f"Using provided captions") + caption_df = pd.read_csv(caption_csv) + # sort images to be consistent with 'sorted' above + caption_df = caption_df.sort_values("image_file") + captions = caption_df["caption"].values + print("Captions: ", captions) + if len(captions) != len(images): + print("Not the same number of captions as images!") + print(f"Num captions: {len(captions)}, Num images: {len(images)}") + print("Captions: ", captions) + print("Images: ", files) + raise Exception( + "Not the same number of captions as images! 
Check that all files passed in have a caption in your caption csv, and vice versa" + ) + + else: + print(f"Generating {len(images)} captions...") + captions = blip_captioning_dataset( + images, text=caption_text, substitution_tokens=substitution_tokens + ) + + if mask_target_prompts is None: + mask_target_prompts = "" + temp = 999 + + print(f"Generating {len(images)} masks...") + if not use_face_detection_instead: + seg_masks = clipseg_mask_generator( + images=images, target_prompts=mask_target_prompts, temp=temp + ) + else: + seg_masks = face_mask_google_mediapipe(images=images) + + # find the center of mass of the mask + if crop_based_on_salience: + coms = [_center_of_mass(mask) for mask in seg_masks] + else: + coms = [(image.size[0] / 2, image.size[1] / 2) for image in images] + # based on the center of mass, crop the image to a square + images = [ + _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms) + ] + + print(f"Upscaling {len(images)} images...") + # upscale images anyways + images = swin_ir_sr(images, target_size=(target_size, target_size)) + images = [ + image.resize((target_size, target_size), Image.Resampling.LANCZOS) + for image in images + ] + + seg_masks = [ + _crop_to_square(mask, com, resize_to=target_size) + for mask, com in zip(seg_masks, coms) + ] + + data = [] + + # clean TEMP_OUT_DIR first + if os.path.exists(output_dir): + for file in os.listdir(output_dir): + os.remove(os.path.join(output_dir, file)) + + os.makedirs(output_dir, exist_ok=True) + + # iterate through the images, masks, and captions and add a row to the dataframe for each + for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)): + image_name = f"{idx}.src.png" + mask_file = f"{idx}.mask.png" + + # save the image and mask files + image.save(output_dir + image_name) + mask.save(output_dir + mask_file) + + # add a new row to the dataframe with the file names and caption + data.append( + {"image_path": image_name, "mask_path": mask_file, "caption": caption}, + ) + + df = pd.DataFrame(columns=["image_path", "mask_path", "caption"], data=data) + # save the dataframe to a CSV file + df.to_csv(os.path.join(output_dir, "captions.csv"), index=False) + + +def _find_files(pattern, dir="."): + """Return list of files matching pattern in a given directory, in absolute format. + Unlike glob, this is case-insensitive. + """ + + rule = re.compile(fnmatch.translate(pattern), re.IGNORECASE) + return [os.path.join(dir, f) for f in os.listdir(dir) if rule.match(f)] diff --git a/cog_sdxl/requirements_test.txt b/cog_sdxl/requirements_test.txt new file mode 100644 index 0000000000000000000000000000000000000000..3184d454bd8f198747ca659aa0820480e119b551 --- /dev/null +++ b/cog_sdxl/requirements_test.txt @@ -0,0 +1,5 @@ +numpy +pytest +replicate +requests +Pillow \ No newline at end of file diff --git a/cog_sdxl/samples.py b/cog_sdxl/samples.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b28989e16bdc9d56c7bdaaab6487de44e956ed --- /dev/null +++ b/cog_sdxl/samples.py @@ -0,0 +1,155 @@ +""" +A handy utility for verifying SDXL image generation locally. 
+To set up, first run a local cog server using: + cog run -p 5000 python -m cog.server.http +Then, in a separate terminal, generate samples + python samples.py +""" + + +import base64 +import os +import sys + +import requests + + +def gen(output_fn, **kwargs): + if os.path.exists(output_fn): + return + + print("Generating", output_fn) + url = "http://localhost:5000/predictions" + response = requests.post(url, json={"input": kwargs}) + data = response.json() + + try: + datauri = data["output"][0] + base64_encoded_data = datauri.split(",")[1] + data = base64.b64decode(base64_encoded_data) + except: + print("Error!") + print("input:", kwargs) + print(data["logs"]) + sys.exit(1) + + with open(output_fn, "wb") as f: + f.write(data) + + +def main(): + SCHEDULERS = [ + "DDIM", + "DPMSolverMultistep", + "HeunDiscrete", + "KarrasDPM", + "K_EULER_ANCESTRAL", + "K_EULER", + "PNDM", + ] + + gen( + f"sample.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=1000, + width=1024, + height=1024, + ) + + for refiner in ["base_image_refiner", "expert_ensemble_refiner", "no_refiner"]: + gen( + f"sample.img2img.{refiner}.png", + prompt="a photo of an astronaut riding a horse on mars", + image="https://huggingface.co./datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png", + prompt_strength=0.8, + num_inference_steps=25, + refine=refiner, + guidance_scale=7, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=42, + ) + + gen( + f"sample.inpaint.{refiner}.png", + prompt="A majestic tiger sitting on a bench", + image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", + mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png", + prompt_strength=0.8, + num_inference_steps=25, + refine=refiner, + guidance_scale=7, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=42, + ) + + for split in range(0, 10): + split = split / 10.0 + gen( + f"sample.expert_ensemble_refiner.{split}.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + refine="expert_ensemble_refiner", + high_noise_frac=split, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=1000, + width=1024, + height=1024, + ) + + gen( + f"sample.refine.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + refine="base_image_refiner", + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=1000, + width=1024, + height=1024, + ) + gen( + f"sample.refine.10.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + refine="base_image_refiner", + refine_steps=10, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + seed=1000, + width=1024, + height=1024, + ) + + gen( + "samples.2.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + scheduler="KarrasDPM", + num_outputs=2, + seed=1000, + width=1024, + 
height=1024, + ) + + for s in SCHEDULERS: + gen( + f"sample.{s}.txt2img.png", + prompt="A studio portrait photo of a cat", + num_inference_steps=25, + guidance_scale=7, + negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", + scheduler=s, + seed=1000, + width=1024, + height=1024, + ) + + +if __name__ == "__main__": + main() diff --git a/cog_sdxl/script/download_preprocessing_weights.py b/cog_sdxl/script/download_preprocessing_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..541f258b753a8b5b02761db44e49b13e11b8852f --- /dev/null +++ b/cog_sdxl/script/download_preprocessing_weights.py @@ -0,0 +1,54 @@ +import argparse +import os +import shutil + +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPSegForImageSegmentation, + CLIPSegProcessor, + Swin2SRForImageSuperResolution, +) + +DEFAULT_BLIP = "Salesforce/blip-image-captioning-large" +DEFAULT_CLIPSEG = "CIDAS/clipseg-rd64-refined" +DEFAULT_SWINIR = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" + + +def upload(args): + blip_processor = BlipProcessor.from_pretrained(DEFAULT_BLIP) + blip_model = BlipForConditionalGeneration.from_pretrained(DEFAULT_BLIP) + + clip_processor = CLIPSegProcessor.from_pretrained(DEFAULT_CLIPSEG) + clip_model = CLIPSegForImageSegmentation.from_pretrained(DEFAULT_CLIPSEG) + + swin_model = Swin2SRForImageSuperResolution.from_pretrained(DEFAULT_SWINIR) + + temp_models = "tmp/models" + if os.path.exists(temp_models): + shutil.rmtree(temp_models) + os.makedirs(temp_models) + + blip_processor.save_pretrained(os.path.join(temp_models, "blip_processor")) + blip_model.save_pretrained(os.path.join(temp_models, "blip_large")) + clip_processor.save_pretrained(os.path.join(temp_models, "clip_seg_processor")) + clip_model.save_pretrained(os.path.join(temp_models, "clip_seg_rd64_refined")) + swin_model.save_pretrained( + os.path.join(temp_models, "swin2sr_realworld_sr_x4_64_bsrgan_psnr") + ) + + for val in os.listdir(temp_models): + if "tar" not in val: + os.system( + f"sudo tar -cvf {os.path.join(temp_models, val)}.tar -C {os.path.join(temp_models, val)} ." + ) + os.system( + f"gcloud storage cp -R {os.path.join(temp_models, val)}.tar gs://{args.bucket}/{val}/" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--bucket", "-m", type=str) + args = parser.parse_args() + upload(args) diff --git a/cog_sdxl/script/download_weights.py b/cog_sdxl/script/download_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e666f91e880c707207be0c52db8c42570ddbeb --- /dev/null +++ b/cog_sdxl/script/download_weights.py @@ -0,0 +1,50 @@ +# Run this before you deploy it on replicate, because if you don't +# whenever you run the model, it will download the weights from the +# internet, which will take a long time. 
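+#
+# A typical invocation (an assumption, not taken from this repo's docs; adjust to
+# your setup) is to run this once inside the cog environment before pushing:
+#
+#   cog run python script/download_weights.py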
+ +import torch +from diffusers import AutoencoderKL, DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) + +# pipe = DiffusionPipeline.from_pretrained( +# "stabilityai/stable-diffusion-xl-base-1.0", +# torch_dtype=torch.float16, +# use_safetensors=True, +# variant="fp16", +# ) + +# pipe.save_pretrained("./cache", safe_serialization=True) + +better_vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 +) + +pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + vae=better_vae, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", +) + +pipe.save_pretrained("./sdxl-cache", safe_serialization=True) + +pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-refiner-1.0", + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", +) + +# TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config. +pipe.save_pretrained("./refiner-cache", safe_serialization=True) + + +safety = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", + torch_dtype=torch.float16, +) + +safety.save_pretrained("./safety-cache") diff --git a/cog_sdxl/tests/assets/out.png b/cog_sdxl/tests/assets/out.png new file mode 100644 index 0000000000000000000000000000000000000000..fd8b9de8476925ae47944b3046655be8c1b377fe --- /dev/null +++ b/cog_sdxl/tests/assets/out.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8fe96688cb1e33a7a99eed8645529eb900c6b7b4f9afeaeee8f4bf0afd762df +size 1520752 diff --git a/cog_sdxl/tests/test_predict.py b/cog_sdxl/tests/test_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..aa73709ef0d30ba52d30153ee86cfbb798671ffc --- /dev/null +++ b/cog_sdxl/tests/test_predict.py @@ -0,0 +1,205 @@ +import base64 +import os +import pickle +import subprocess +import sys +import time +from functools import partial +from io import BytesIO + +import numpy as np +import pytest +import replicate +import requests +from PIL import Image, ImageChops + +ENV = os.getenv('TEST_ENV', 'local') +LOCAL_ENDPOINT = "http://localhost:5000/predictions" +MODEL = os.getenv('STAGING_MODEL', 'no model configured') + +def local_run(model_endpoint: str, model_input: dict): + response = requests.post(model_endpoint, json={"input": model_input}) + data = response.json() + + try: + # TODO: this will break if we test batching + datauri = data["output"][0] + base64_encoded_data = datauri.split(",")[1] + data = base64.b64decode(base64_encoded_data) + return Image.open(BytesIO(data)) + except Exception as e: + print("Error!") + print("input:", model_input) + print(data["logs"]) + raise e + + +def replicate_run(model: str, version: str, model_input: dict): + output = replicate.run( + f"{model}:{version}", + input=model_input) + url = output[0] + + response = requests.get(url) + return Image.open(BytesIO(response.content)) + + +def wait_for_server_to_be_ready(url, timeout=300): + """ + Waits for the server to be ready. + + Args: + - url: The health check URL to poll. + - timeout: Maximum time (in seconds) to wait for the server to be ready. 
+ """ + start_time = time.time() + while True: + try: + response = requests.get(url) + data = response.json() + + if data["status"] == "READY": + return + elif data["status"] == "SETUP_FAILED": + raise RuntimeError( + "Server initialization failed with status: SETUP_FAILED" + ) + + except requests.RequestException: + pass + + if time.time() - start_time > timeout: + raise TimeoutError("Server did not become ready in the expected time.") + + time.sleep(5) # Poll every 5 seconds + + +@pytest.fixture(scope="session") +def inference_func(): + """ + local inference uses http API to hit local server; staging inference uses python API b/c it's cleaner. + """ + if ENV == 'local': + return partial(local_run, LOCAL_ENDPOINT) + elif ENV == 'staging': + model = replicate.models.get(MODEL) + print(f"model,", model) + version = model.versions.list()[0] + return partial(replicate_run, MODEL, version.id) + else: + raise Exception(f"env should be local or staging but was {ENV}") + + +@pytest.fixture(scope="session", autouse=True) +def service(): + """ + Spins up local cog server to hit for tests if running locally, no-op otherwise + """ + if ENV == 'local': + print("building model") + # starts local server if we're running things locally + build_command = 'cog build -t test-model'.split() + subprocess.run(build_command, check=True) + container_name = 'cog-test' + try: + subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name]) + print(f"Container '{container_name}' is running. Stopping and removing...") + subprocess.check_call(['docker', 'stop', container_name]) + subprocess.check_call(['docker', 'rm', container_name]) + print(f"Container '{container_name}' stopped and removed.") + except subprocess.CalledProcessError: + # Container not found + print(f"Container '{container_name}' not found or not running.") + + run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split() + process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr) + + wait_for_server_to_be_ready("http://localhost:5000/health-check") + + yield + process.terminate() + process.wait() + stop_command = "docker stop cog-test".split() + subprocess.run(stop_command) + else: + yield + + +def image_equal_fuzzy(img_expected, img_actual, test_name='default', tol=20): + """ + Assert that average pixel values differ by less than tol across image + Tol determined empirically - holding everything else equal but varying seed + generates images that vary by at least 50 + """ + img1 = np.array(img_expected, dtype=np.int32) + img2 = np.array(img_actual, dtype=np.int32) + + mean_delta = np.mean(np.abs(img1 - img2)) + imgs_equal = (mean_delta < tol) + if not imgs_equal: + # save failures for quick inspection + save_dir = f"tmp/{test_name}" + if not os.path.exists(save_dir): + os.makedirs(save_dir) + img_expected.save(os.path.join(save_dir, 'expected.png')) + img_actual.save(os.path.join(save_dir, 'actual.png')) + difference = ImageChops.difference(img_expected, img_actual) + difference.save(os.path.join(save_dir, 'delta.png')) + + return imgs_equal + + +def test_seeded_prediction(inference_func, request): + """ + SDXL w/seed should be deterministic. 
may need to adjust tolerance for optimized SDXLs + """ + data = { + "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic", + "num_inference_steps": 50, + "width": 1024, + "height": 1024, + "scheduler": "DDIM", + "refine": "expert_ensemble_refiner", + "seed": 12103, + } + actual_image = inference_func(data) + expected_image = Image.open("tests/assets/out.png") + assert image_equal_fuzzy(actual_image, expected_image, test_name=request.node.name) + + +def test_lora_load_unload(inference_func, request): + """ + Tests generation with & without loras. + This is checking for some gnarly state issues (can SDXL load / unload LoRAs), so predictions need to run in series. + """ + SEED = 1234 + base_data = { + "prompt": "A photo of a dog on the beach", + "num_inference_steps": 50, + # Add other parameters here + "seed": SEED, + } + base_img_1 = inference_func(base_data) + + lora_a_data = { + "prompt": "A photo of a TOK on the beach", + "num_inference_steps": 50, + # Add other parameters here + "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/other_model.tar", + "seed": SEED + } + lora_a_img_1 = inference_func(lora_a_data) + assert not image_equal_fuzzy(lora_a_img_1, base_img_1, test_name=request.node.name) + + lora_a_img_2 = inference_func(lora_a_data) + assert image_equal_fuzzy(lora_a_img_1, lora_a_img_2, test_name=request.node.name) + + lora_b_data = { + "prompt": "A photo of a TOK on the beach", + "num_inference_steps": 50, + "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/monstertoy_model.tar", + "seed": SEED, + } + lora_b_img = inference_func(lora_b_data) + assert not image_equal_fuzzy(lora_a_img_1, lora_b_img, test_name=request.node.name) + assert not image_equal_fuzzy(base_img_1, lora_b_img, test_name=request.node.name) diff --git a/cog_sdxl/tests/test_remote_train.py b/cog_sdxl/tests/test_remote_train.py new file mode 100644 index 0000000000000000000000000000000000000000..26a2792833aad3264d918de497d91aaef05be6d7 --- /dev/null +++ b/cog_sdxl/tests/test_remote_train.py @@ -0,0 +1,69 @@ +import time +import pytest +import replicate + + +@pytest.fixture(scope="module") +def model_name(request): + return "stability-ai/sdxl" + + +@pytest.fixture(scope="module") +def model(model_name): + return replicate.models.get(model_name) + + +@pytest.fixture(scope="module") +def version(model): + versions = model.versions.list() + return versions[0] + + +@pytest.fixture(scope="module") +def training(model_name, version): + training_input = { + "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar" + } + print(f"Training on {model_name}:{version.id}") + return replicate.trainings.create( + version=model_name + ":" + version.id, + input=training_input, + destination="replicate-internal/training-scratch", + ) + + +@pytest.fixture(scope="module") +def prediction_tests(): + return [ + { + "prompt": "A photo of TOK at the beach", + "refine": "expert_ensemble_refiner", + }, + ] + + +def test_training(training): + while training.completed_at is None: + time.sleep(60) + training.reload() + assert training.status == "succeeded" + + +@pytest.fixture(scope="module") +def trained_model_and_version(training): + trained_model, trained_version = training.output["version"].split(":") + return trained_model, trained_version + + +def test_post_training_predictions(trained_model_and_version, prediction_tests): + trained_model, trained_version = trained_model_and_version + model = 
replicate.models.get(trained_model) + version = model.versions.get(trained_version) + predictions = [ + replicate.predictions.create(version=version, input=val) + for val in prediction_tests + ] + + for val in predictions: + val.wait() + assert val.status == "succeeded" diff --git a/cog_sdxl/tests/test_utils.py b/cog_sdxl/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7d4e828c76d24c2b62e5ee66b516ee0337205a --- /dev/null +++ b/cog_sdxl/tests/test_utils.py @@ -0,0 +1,105 @@ +import os +import json +import requests +import time +from threading import Thread, Lock +import re +import multiprocessing +import subprocess + +ERROR_PATTERN = re.compile(r"ERROR:") + + +def get_image_name(): + current_dir = os.path.basename(os.getcwd()) + + if "cog" in current_dir: + return current_dir + else: + return f"cog-{current_dir}" + + +def process_log_line(line): + line = line.decode("utf-8").strip() + try: + log_data = json.loads(line) + return json.dumps(log_data, indent=2) + except json.JSONDecodeError: + return line + + +def capture_output(pipe, print_lock, logs=None, error_detected=None): + for line in iter(pipe.readline, b""): + formatted_line = process_log_line(line) + with print_lock: + print(formatted_line) + if logs is not None: + logs.append(formatted_line) + if error_detected is not None: + if ERROR_PATTERN.search(formatted_line): + error_detected[0] = True + + +def wait_for_server_to_be_ready(url, timeout=300): + """ + Waits for the server to be ready. + + Args: + - url: The health check URL to poll. + - timeout: Maximum time (in seconds) to wait for the server to be ready. + """ + start_time = time.time() + while True: + try: + response = requests.get(url) + data = response.json() + + if data["status"] == "READY": + return + elif data["status"] == "SETUP_FAILED": + raise RuntimeError( + "Server initialization failed with status: SETUP_FAILED" + ) + + except requests.RequestException: + pass + + if time.time() - start_time > timeout: + raise TimeoutError("Server did not become ready in the expected time.") + + time.sleep(5) # Poll every 5 seconds + + +def run_training_subprocess(command): + # Start the subprocess with pipes for stdout and stderr + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + # Create a lock for printing and a list to accumulate logs + print_lock = multiprocessing.Lock() + logs = multiprocessing.Manager().list() + error_detected = multiprocessing.Manager().list([False]) + + # Start two separate processes to handle stdout and stderr + stdout_processor = multiprocessing.Process( + target=capture_output, args=(process.stdout, print_lock, logs, error_detected) + ) + stderr_processor = multiprocessing.Process( + target=capture_output, args=(process.stderr, print_lock, logs, error_detected) + ) + + # Start the log processors + stdout_processor.start() + stderr_processor.start() + + # Wait for the subprocess to finish + process.wait() + + # Wait for the log processors to finish + stdout_processor.join() + stderr_processor.join() + + # Check if an error pattern was detected + if error_detected[0]: + raise Exception("Error detected in training logs! 
Check logs for details") + + return list(logs) diff --git a/cog_sdxl/train.py b/cog_sdxl/train.py new file mode 100644 index 0000000000000000000000000000000000000000..331327b56625c0d20ffd42a2ac084f5d37191c83 --- /dev/null +++ b/cog_sdxl/train.py @@ -0,0 +1,197 @@ +import os +import shutil +import tarfile + +from cog import BaseModel, Input, Path + +from predict import SDXL_MODEL_CACHE, SDXL_URL, download_weights +from preprocess import preprocess +from trainer_pti import main + +""" +Wrapper around actual trainer. +""" +OUTPUT_DIR = "training_out" + + +class TrainingOutput(BaseModel): + weights: Path + + +from typing import Tuple + + +def train( + input_images: Path = Input( + description="A .zip or .tar file containing the image files that will be used for fine-tuning" + ), + seed: int = Input( + description="Random seed for reproducible training. Leave empty to use a random seed", + default=None, + ), + resolution: int = Input( + description="Square pixel resolution which your images will be resized to for training", + default=768, + ), + train_batch_size: int = Input( + description="Batch size (per device) for training", + default=4, + ), + num_train_epochs: int = Input( + description="Number of epochs to loop through your training dataset", + default=4000, + ), + max_train_steps: int = Input( + description="Number of individual training steps. Takes precedence over num_train_epochs", + default=1000, + ), + # gradient_accumulation_steps: int = Input( + # description="Number of training steps to accumulate before a backward pass. Effective batch size = gradient_accumulation_steps * batch_size", + # default=1, + # ), # todo. + is_lora: bool = Input( + description="Whether to use LoRA training. If set to False, will use Full fine tuning", + default=True, + ), + unet_learning_rate: float = Input( + description="Learning rate for the U-Net. We recommend this value to be somewhere between `1e-6` to `1e-5`.", + default=1e-6, + ), + ti_lr: float = Input( + description="Scaling of learning rate for training textual inversion embeddings. Don't alter unless you know what you're doing.", + default=3e-4, + ), + lora_lr: float = Input( + description="Scaling of learning rate for training LoRA embeddings. Don't alter unless you know what you're doing.", + default=1e-4, + ), + lora_rank: int = Input( + description="Rank of LoRA embeddings. Don't alter unless you know what you're doing.", + default=32, + ), + lr_scheduler: str = Input( + description="Learning rate scheduler to use for training", + default="constant", + choices=[ + "constant", + "linear", + ], + ), + lr_warmup_steps: int = Input( + description="Number of warmup steps for lr schedulers with warmups.", + default=100, + ), + token_string: str = Input( + description="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well", + default="TOK", + ), + # token_map: str = Input( + # description="String of token and their impact size specificing tokens used in the dataset. This will be in format of `token1:size1,token2:size2,...`.", + # default="TOK:2", + # ), + caption_prefix: str = Input( + description="Text which will be used as prefix during automatic captioning. Must contain the `token_string`. 
For example, if caption text is 'a photo of TOK', automatic captioning will expand to 'a photo of TOK under a bridge', 'a photo of TOK holding a cup', etc.", + default="a photo of TOK, ", + ), + mask_target_prompts: str = Input( + description="Prompt that describes part of the image that you will find important. For example, if you are fine-tuning your pet, `photo of a dog` will be a good prompt. Prompt-based masking is used to focus the fine-tuning process on the important/salient parts of the image", + default=None, + ), + crop_based_on_salience: bool = Input( + description="If you want to crop the image to `target_size` based on the important parts of the image, set this to True. If you want to crop the image based on face detection, set this to False", + default=True, + ), + use_face_detection_instead: bool = Input( + description="If you want to use face detection instead of CLIPSeg for masking. For face applications, we recommend using this option.", + default=False, + ), + clipseg_temperature: float = Input( + description="How blurry you want the CLIPSeg mask to be. We recommend this value be something between `0.5` to `1.0`. If you want to have more sharp mask (but thus more errorful), you can decrease this value.", + default=1.0, + ), + verbose: bool = Input(description="verbose output", default=True), + checkpointing_steps: int = Input( + description="Number of steps between saving checkpoints. Set to very very high number to disable checkpointing, because you don't need one.", + default=999999, + ), + input_images_filetype: str = Input( + description="Filetype of the input images. Can be either `zip` or `tar`. By default its `infer`, and it will be inferred from the ext of input file.", + default="infer", + choices=["zip", "tar", "infer"], + ), +) -> TrainingOutput: + # Hard-code token_map for now. Make it configurable once we support multiple concepts or user-uploaded caption csv. 
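+    # token_map format is "<token_string>:<n_placeholders>"; e.g. "TOK:2" maps the
+    # user-facing token string to two trainable placeholder tokens built below.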
+    token_map = token_string + ":2"
+
+    # Process 'token_to_train' and 'input_data_tar_or_zip'
+    inserting_list_tokens = token_map.split(",")
+
+    token_dict = {}
+    running_tok_cnt = 0
+    all_token_lists = []
+    for token in inserting_list_tokens:
+        n_tok = int(token.split(":")[1])
+
+        token_dict[token.split(":")[0]] = "".join(
+            [f"<s{i + running_tok_cnt}>" for i in range(n_tok)]
+        )
+        all_token_lists.extend([f"<s{i + running_tok_cnt}>" for i in range(n_tok)])
+
+        running_tok_cnt += n_tok
+
+    input_dir = preprocess(
+        input_images_filetype=input_images_filetype,
+        input_zip_path=input_images,
+        caption_text=caption_prefix,
+        mask_target_prompts=mask_target_prompts,
+        target_size=resolution,
+        crop_based_on_salience=crop_based_on_salience,
+        use_face_detection_instead=use_face_detection_instead,
+        temp=clipseg_temperature,
+        substitution_tokens=list(token_dict.keys()),
+    )
+
+    if not os.path.exists(SDXL_MODEL_CACHE):
+        download_weights(SDXL_URL, SDXL_MODEL_CACHE)
+    if os.path.exists(OUTPUT_DIR):
+        shutil.rmtree(OUTPUT_DIR)
+    os.makedirs(OUTPUT_DIR)
+
+    main(
+        pretrained_model_name_or_path=SDXL_MODEL_CACHE,
+        instance_data_dir=os.path.join(input_dir, "captions.csv"),
+        output_dir=OUTPUT_DIR,
+        seed=seed,
+        resolution=resolution,
+        train_batch_size=train_batch_size,
+        num_train_epochs=num_train_epochs,
+        max_train_steps=max_train_steps,
+        gradient_accumulation_steps=1,
+        unet_learning_rate=unet_learning_rate,
+        ti_lr=ti_lr,
+        lora_lr=lora_lr,
+        lr_scheduler=lr_scheduler,
+        lr_warmup_steps=lr_warmup_steps,
+        token_dict=token_dict,
+        inserting_list_tokens=all_token_lists,
+        verbose=verbose,
+        checkpointing_steps=checkpointing_steps,
+        scale_lr=False,
+        max_grad_norm=1.0,
+        allow_tf32=True,
+        mixed_precision="bf16",
+        device="cuda:0",
+        lora_rank=lora_rank,
+        is_lora=is_lora,
+    )
+
+    directory = Path(OUTPUT_DIR)
+    out_path = "trained_model.tar"
+
+    with tarfile.open(out_path, "w") as tar:
+        for file_path in directory.rglob("*"):
+            print(file_path)
+            arcname = file_path.relative_to(directory)
+            tar.add(file_path, arcname=arcname)
+
+    return TrainingOutput(weights=Path(out_path))
diff --git a/cog_sdxl/trainer_pti.py b/cog_sdxl/trainer_pti.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b81de8963b6aa588480918186cd318676990ca
--- /dev/null
+++ b/cog_sdxl/trainer_pti.py
@@ -0,0 +1,403 @@
+# Bootstrapped from Huggingface diffuser's code.
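+#
+# This module implements the pivotal-tuning (PTI) style loop driven by train.py:
+# new placeholder token embeddings are trained in both SDXL text encoders while
+# the U-Net is fine-tuned either fully or through LoRA attention processors, and
+# the textual-inversion parameters are dropped from the optimizer halfway through
+# training when `pivot_halfway` is set.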
+import fnmatch
+import json
+import math
+import os
+import shutil
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
+from diffusers.optimization import get_scheduler
+from safetensors.torch import save_file
+from tqdm.auto import tqdm
+
+from dataset_and_utils import (
+    PreprocessedDataset,
+    TokenEmbeddingsHandler,
+    load_models,
+    unet_attn_processors_state_dict,
+)
+
+
+def main(
+    pretrained_model_name_or_path: Optional[
+        str
+    ] = "./cache",  # "stabilityai/stable-diffusion-xl-base-1.0",
+    revision: Optional[str] = None,
+    instance_data_dir: Optional[str] = "./dataset/zeke/captions.csv",
+    output_dir: str = "ft_masked_coke",
+    seed: Optional[int] = 42,
+    resolution: int = 512,
+    crops_coords_top_left_h: int = 0,
+    crops_coords_top_left_w: int = 0,
+    train_batch_size: int = 1,
+    do_cache: bool = True,
+    num_train_epochs: int = 600,
+    max_train_steps: Optional[int] = None,
+    checkpointing_steps: int = 500000,  # default to no checkpoints
+    gradient_accumulation_steps: int = 1,  # todo
+    unet_learning_rate: float = 1e-5,
+    ti_lr: float = 3e-4,
+    lora_lr: float = 1e-4,
+    pivot_halfway: bool = True,
+    scale_lr: bool = False,
+    lr_scheduler: str = "constant",
+    lr_warmup_steps: int = 500,
+    lr_num_cycles: int = 1,
+    lr_power: float = 1.0,
+    dataloader_num_workers: int = 0,
+    max_grad_norm: float = 1.0,  # todo with tests
+    allow_tf32: bool = True,
+    mixed_precision: Optional[str] = "bf16",
+    device: str = "cuda:0",
+    token_dict: dict = {"TOKEN": "<s0><s1>"},
+    inserting_list_tokens: List[str] = ["<s0>", "<s1>"],
+    verbose: bool = True,
+    is_lora: bool = True,
+    lora_rank: int = 32,
+) -> None:
+    if allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+    if not seed:
+        seed = np.random.randint(0, 2**32 - 1)
+    print("Using seed", seed)
+    torch.manual_seed(seed)
+
+    weight_dtype = torch.float32
+    if mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    if scale_lr:
+        unet_learning_rate = (
+            unet_learning_rate * gradient_accumulation_steps * train_batch_size
+        )
+
+    (
+        tokenizer_one,
+        tokenizer_two,
+        noise_scheduler,
+        text_encoder_one,
+        text_encoder_two,
+        vae,
+        unet,
+    ) = load_models(pretrained_model_name_or_path, revision, device, weight_dtype)
+
+    print("# PTI : Loaded models")
+
+    # Initialize new tokens for training.
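+    # (assumption: TokenEmbeddingsHandler adds the placeholder tokens, e.g. <s0>, <s1>,
+    # to both tokenizers and resizes/initializes the matching embedding rows; see
+    # dataset_and_utils.py)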
+ + embedding_handler = TokenEmbeddingsHandler( + [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two] + ) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_list_tokens) + + text_encoders = [text_encoder_one, text_encoder_two] + + unet_param_to_optimize = [] + # fine tune only attn weights + + text_encoder_parameters = [] + for text_encoder in text_encoders: + for name, param in text_encoder.named_parameters(): + if "token_embedding" in name: + param.requires_grad = True + print(name) + text_encoder_parameters.append(param) + else: + param.requires_grad = False + + if not is_lora: + WHITELIST_PATTERNS = [ + # "*.attn*.weight", + # "*ff*.weight", + "*" + ] # TODO : make this a parameter + BLACKLIST_PATTERNS = ["*.norm*.weight", "*time*"] + + unet_param_to_optimize_names = [] + for name, param in unet.named_parameters(): + if any( + fnmatch.fnmatch(name, pattern) for pattern in WHITELIST_PATTERNS + ) and not any( + fnmatch.fnmatch(name, pattern) for pattern in BLACKLIST_PATTERNS + ): + param.requires_grad_(True) + unet_param_to_optimize_names.append(name) + print(f"Training: {name}") + else: + param.requires_grad_(False) + + # Optimizer creation + params_to_optimize = [ + { + "params": unet_param_to_optimize, + "lr": unet_learning_rate, + }, + { + "params": text_encoder_parameters, + "lr": ti_lr, + "weight_decay": 1e-3, + }, + ] + + else: + # Do lora-training instead. + unet.requires_grad_(False) + unet_lora_attn_procs = {} + unet_lora_parameters = [] + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + module = LoRAAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=lora_rank, + ) + unet_lora_attn_procs[name] = module + module.to(device) + unet_lora_parameters.extend(module.parameters()) + + unet.set_attn_processor(unet_lora_attn_procs) + + params_to_optimize = [ + { + "params": unet_lora_parameters, + "lr": lora_lr, + }, + { + "params": text_encoder_parameters, + "lr": ti_lr, + "weight_decay": 1e-3, + }, + ] + + optimizer = torch.optim.AdamW( + params_to_optimize, + weight_decay=1e-4, + ) + + print(f"# PTI : Loading dataset, do_cache {do_cache}") + + train_dataset = PreprocessedDataset( + instance_data_dir, + tokenizer_one, + tokenizer_two, + vae.float(), + do_cache=True, + substitute_caption_map=token_dict, + ) + + print("# PTI : Loaded dataset") + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=dataloader_num_workers, + ) + + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / gradient_accumulation_steps + ) + if max_train_steps is None: + max_train_steps = num_train_epochs * num_update_steps_per_epoch + + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=max_train_steps * gradient_accumulation_steps, + num_cycles=lr_num_cycles, + power=lr_power, + ) + + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / 
gradient_accumulation_steps + ) + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + + total_batch_size = train_batch_size * gradient_accumulation_steps + + if verbose: + print(f"# PTI : Running training ") + print(f"# PTI : Num examples = {len(train_dataset)}") + print(f"# PTI : Num batches each epoch = {len(train_dataloader)}") + print(f"# PTI : Num Epochs = {num_train_epochs}") + print(f"# PTI : Instantaneous batch size per device = {train_batch_size}") + print( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + print(f"# PTI : Gradient Accumulation steps = {gradient_accumulation_steps}") + print(f"# PTI : Total optimization steps = {max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, max_train_steps)) + checkpoint_dir = "checkpoint" + if os.path.exists(checkpoint_dir): + shutil.rmtree(checkpoint_dir) + + os.makedirs(f"{checkpoint_dir}/unet", exist_ok=True) + os.makedirs(f"{checkpoint_dir}/embeddings", exist_ok=True) + + for epoch in range(first_epoch, num_train_epochs): + if pivot_halfway: + if epoch == num_train_epochs // 2: + print("# PTI : Pivot halfway") + # remove text encoder parameters from optimizer + params_to_optimize = params_to_optimize[:1] + optimizer = torch.optim.AdamW( + params_to_optimize, + weight_decay=1e-4, + ) + + unet.train() + for step, batch in enumerate(train_dataloader): + progress_bar.update(1) + progress_bar.set_description(f"# PTI :step: {global_step}, epoch: {epoch}") + global_step += 1 + + (tok1, tok2), vae_latent, mask = batch + vae_latent = vae_latent.to(weight_dtype) + + # tokens to text embeds + prompt_embeds_list = [] + for tok, text_encoder in zip((tok1, tok2), text_encoders): + prompt_embeds_out = text_encoder( + tok.to(text_encoder.device), + output_hidden_states=True, + ) + + pooled_prompt_embeds = prompt_embeds_out[0] + prompt_embeds = prompt_embeds_out.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + + # Create Spatial-dimensional conditions. + + original_size = (resolution, resolution) + target_size = (resolution, resolution) + crops_coords_top_left = (crops_coords_top_left_h, crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype).repeat( + bs_embed, 1 + ) + + added_kw = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} + + # Sample noise that we'll add to the latents + noise = torch.randn_like(vae_latent) + bsz = vae_latent.shape[0] + + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=vae_latent.device, + ) + timesteps = timesteps.long() + + noisy_model_input = noise_scheduler.add_noise(vae_latent, noise, timesteps) + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=added_kw, + ).sample + + loss = (model_pred - noise).pow(2) * mask + loss = loss.mean() + + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. 
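+            # (assumption: retract_embeddings restores every embedding row except the
+            # newly inserted placeholder tokens, so only <s0>, <s1>, ... accumulate updates)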
+ + for idx, text_encoder in enumerate(text_encoders): + embedding_handler.retract_embeddings() + + if global_step % checkpointing_steps == 0: + # save the required params of unet with safetensor + + if not is_lora: + tensors = { + name: param + for name, param in unet.named_parameters() + if name in unet_param_to_optimize_names + } + save_file( + tensors, + f"{checkpoint_dir}/unet/checkpoint-{global_step}.unet.safetensors", + ) + + else: + lora_tensors = unet_attn_processors_state_dict(unet) + + save_file( + lora_tensors, + f"{checkpoint_dir}/unet/checkpoint-{global_step}.lora.safetensors", + ) + + embedding_handler.save_embeddings( + f"{checkpoint_dir}/embeddings/checkpoint-{global_step}.pti", + ) + + # final_save + print("Saving final model for return") + if not is_lora: + tensors = { + name: param + for name, param in unet.named_parameters() + if name in unet_param_to_optimize_names + } + save_file( + tensors, + f"{output_dir}/unet.safetensors", + ) + else: + lora_tensors = unet_attn_processors_state_dict(unet) + save_file( + lora_tensors, + f"{output_dir}/lora.safetensors", + ) + + embedding_handler.save_embeddings( + f"{output_dir}/embeddings.pti", + ) + + to_save = token_dict + with open(f"{output_dir}/special_params.json", "w") as f: + json.dump(to_save, f) + + +if __name__ == "__main__": + main() diff --git a/cog_sdxl/weights.py b/cog_sdxl/weights.py new file mode 100644 index 0000000000000000000000000000000000000000..06fb29fa1281531f0866d4677d1727e82a092459 --- /dev/null +++ b/cog_sdxl/weights.py @@ -0,0 +1,127 @@ +from collections import deque +import hashlib +import os +import shutil +import subprocess +import time + + +class WeightsDownloadCache: + def __init__( + self, min_disk_free: int = 10 * (2**30), base_dir: str = "/src/weights-cache" + ): + """ + WeightsDownloadCache is meant to track and download weights files as fast + as possible, while ensuring there's enough disk space. + + It tries to keep the most recently used weights files in the cache, so + ensure you call ensure() on the weights each time you use them. + + It will not re-download weights files that are already in the cache. + + :param min_disk_free: Minimum disk space required to start download, in bytes. + :param base_dir: The base directory to store weights files. + """ + self.min_disk_free = min_disk_free + self.base_dir = base_dir + self._hits = 0 + self._misses = 0 + + # Least Recently Used (LRU) cache for paths + self.lru_paths = deque() + if not os.path.exists(base_dir): + os.makedirs(base_dir) + + def _remove_least_recent(self) -> None: + """ + Remove the least recently used weights file from the cache and disk. + """ + oldest = self.lru_paths.popleft() + self._rm_disk(oldest) + + def cache_info(self) -> str: + """ + Get cache information. + + :return: Cache information. + """ + + return f"CacheInfo(hits={self._hits}, misses={self._misses}, base_dir='{self.base_dir}', currsize={len(self.lru_paths)})" + + def _rm_disk(self, path: str) -> None: + """ + Remove a weights file or directory from disk. + :param path: Path to remove. + """ + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + def _has_enough_space(self) -> bool: + """ + Check if there's enough disk space. + + :return: True if there's more than min_disk_free free, False otherwise. 
+ """ + disk_usage = shutil.disk_usage(self.base_dir) + print(f"Free disk space: {disk_usage.free}") + return disk_usage.free >= self.min_disk_free + + def ensure(self, url: str) -> str: + """ + Ensure weights file is in the cache and return its path. + + This also updates the LRU cache to mark the weights as recently used. + + :param url: URL to download weights file from, if not in cache. + :return: Path to weights. + """ + path = self.weights_path(url) + + if path in self.lru_paths: + # here we remove to re-add to the end of the LRU (marking it as recently used) + self._hits += 1 + self.lru_paths.remove(path) + else: + self._misses += 1 + self.download_weights(url, path) + + self.lru_paths.append(path) # Add file to end of cache + return path + + def weights_path(self, url: str) -> str: + """ + Generate path to store a weights file based hash of the URL. + + :param url: URL to download weights file from. + :return: Path to store weights file. + """ + hashed_url = hashlib.sha256(url.encode()).hexdigest() + short_hash = hashed_url[:16] # Use the first 16 characters of the hash + return os.path.join(self.base_dir, short_hash) + + def download_weights(self, url: str, dest: str) -> None: + """ + Download weights file from a URL, ensuring there's enough disk space. + + :param url: URL to download weights file from. + :param dest: Path to store weights file. + """ + print("Ensuring enough disk space...") + while not self._has_enough_space() and len(self.lru_paths) > 0: + self._remove_least_recent() + + print(f"Downloading weights: {url}") + + st = time.time() + # maybe retry with the real url if this doesn't work + try: + output = subprocess.check_output(["pget", "-x", url, dest], close_fds=True) + print(output) + except subprocess.CalledProcessError as e: + # If download fails, clean up and re-raise exception + print(e.output) + self._rm_disk(dest) + raise e + print(f"Downloaded weights in {time.time() - st} seconds")