import random import math import os import torch import torch.utils.checkpoint from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm import numpy as np from PIL import Image from trainer_util import * from clip_segmentation import ClipSeg class bcolors: HEADER = '\033[95m' OKBLUE = '\033[94m' OKCYAN = '\033[96m' OKGREEN = '\033[92m' WARNING = '\033[93m' FAIL = '\033[91m' ENDC = '\033[0m' BOLD = '\033[1m' UNDERLINE = '\033[4m' ASPECT_2048 = [[2048, 2048], [2112, 1984],[1984, 2112], [2176, 1920],[1920, 2176], [2240, 1856],[1856, 2240], [2304, 1792],[1792, 2304], [2368, 1728],[1728, 2368], [2432, 1664],[1664, 2432], [2496, 1600],[1600, 2496], [2560, 1536],[1536, 2560], [2624, 1472],[1472, 2624]] ASPECT_1984 = [[1984, 1984], [2048, 1920],[1920, 2048], [2112, 1856],[1856, 2112], [2176, 1792],[1792, 2176], [2240, 1728],[1728, 2240], [2304, 1664],[1664, 2304], [2368, 1600],[1600, 2368], [2432, 1536],[1536, 2432], [2496, 1472],[1472, 2496], [2560, 1408],[1408, 2560]] ASPECT_1920 = [[1920, 1920], [1984, 1856],[1856, 1984], [2048, 1792],[1792, 2048], [2112, 1728],[1728, 2112], [2176, 1664],[1664, 2176], [2240, 1600],[1600, 2240], [2304, 1536],[1536, 2304], [2368, 1472],[1472, 2368], [2432, 1408],[1408, 2432], [2496, 1344],[1344, 2496]] ASPECT_1856 = [[1856, 1856], [1920, 1792],[1792, 1920], [1984, 1728],[1728, 1984], [2048, 1664],[1664, 2048], [2112, 1600],[1600, 2112], [2176, 1536],[1536, 2176], [2240, 1472],[1472, 2240], [2304, 1408],[1408, 2304], [2368, 1344],[1344, 2368], [2432, 1280],[1280, 2432]] ASPECT_1792 = [[1792, 1792], [1856, 1728],[1728, 1856], [1920, 1664],[1664, 1920], [1984, 1600],[1600, 1984], [2048, 1536],[1536, 2048], [2112, 1472],[1472, 2112], [2176, 1408],[1408, 2176], [2240, 1344],[1344, 2240], [2304, 1280],[1280, 2304], [2368, 1216],[1216, 2368]] ASPECT_1728 = [[1728, 1728], [1792, 1664],[1664, 1792], [1856, 1600],[1600, 1856], [1920, 1536],[1536, 1920], [1984, 1472],[1472, 1984], [2048, 1408],[1408, 2048], [2112, 1344],[1344, 2112], [2176, 1280],[1280, 2176], [2240, 1216],[1216, 2240], [2304, 1152],[1152, 2304]] ASPECT_1664 = [[1664, 1664], [1728, 1600],[1600, 1728], [1792, 1536],[1536, 1792], [1856, 1472],[1472, 1856], [1920, 1408],[1408, 1920], [1984, 1344],[1344, 1984], [2048, 1280],[1280, 2048], [2112, 1216],[1216, 2112], [2176, 1152],[1152, 2176], [2240, 1088],[1088, 2240]] ASPECT_1600 = [[1600, 1600], [1664, 1536],[1536, 1664], [1728, 1472],[1472, 1728], [1792, 1408],[1408, 1792], [1856, 1344],[1344, 1856], [1920, 1280],[1280, 1920], [1984, 1216],[1216, 1984], [2048, 1152],[1152, 2048], [2112, 1088],[1088, 2112], [2176, 1024],[1024, 2176]] ASPECT_1536 = [[1536, 1536], [1600, 1472],[1472, 1600], [1664, 1408],[1408, 1664], [1728, 1344],[1344, 1728], [1792, 1280],[1280, 1792], [1856, 1216],[1216, 1856], [1920, 1152],[1152, 1920], [1984, 1088],[1088, 1984], [2048, 1024],[1024, 2048], [2112, 960],[960, 2112]] ASPECT_1472 = [[1472, 1472], [1536, 1408],[1408, 1536], [1600, 1344],[1344, 1600], [1664, 1280],[1280, 1664], [1728, 1216],[1216, 1728], [1792, 1152],[1152, 1792], [1856, 1088],[1088, 1856], [1920, 1024],[1024, 1920], [1984, 960],[960, 1984], [2048, 896],[896, 2048]] ASPECT_1408 = [[1408, 1408], [1472, 1344],[1344, 1472], [1536, 1280],[1280, 1536], [1600, 1216],[1216, 1600], [1664, 1152],[1152, 1664], [1728, 1088],[1088, 1728], [1792, 1024],[1024, 1792], [1856, 960],[960, 1856], [1920, 896],[896, 1920], [1984, 832],[832, 1984]] ASPECT_1344 = [[1344, 1344], [1408, 1280],[1280, 1408], [1472, 1216],[1216, 1472], [1536, 1152],[1152, 1536], [1600, 1088],[1088, 1600], [1664, 1024],[1024, 1664], [1728, 960],[960, 1728], [1792, 896],[896, 1792], [1856, 832],[832, 1856], [1920, 768],[768, 1920]] ASPECT_1280 = [[1280, 1280], [1344, 1216],[1216, 1344], [1408, 1152],[1152, 1408], [1472, 1088],[1088, 1472], [1536, 1024],[1024, 1536], [1600, 960],[960, 1600], [1664, 896],[896, 1664], [1728, 832],[832, 1728], [1792, 768],[768, 1792], [1856, 704],[704, 1856]] ASPECT_1216 = [[1216, 1216], [1280, 1152],[1152, 1280], [1344, 1088],[1088, 1344], [1408, 1024],[1024, 1408], [1472, 960],[960, 1472], [1536, 896],[896, 1536], [1600, 832],[832, 1600], [1664, 768],[768, 1664], [1728, 704],[704, 1728], [1792, 640],[640, 1792]] ASPECT_1152 = [[1152, 1152], [1216, 1088],[1088, 1216], [1280, 1024],[1024, 1280], [1344, 960],[960, 1344], [1408, 896],[896, 1408], [1472, 832],[832, 1472], [1536, 768],[768, 1536], [1600, 704],[704, 1600], [1664, 640],[640, 1664], [1728, 576],[576, 1728]] ASPECT_1088 = [[1088, 1088], [1152, 1024],[1024, 1152], [1216, 960],[960, 1216], [1280, 896],[896, 1280], [1344, 832],[832, 1344], [1408, 768],[768, 1408], [1472, 704],[704, 1472], [1536, 640],[640, 1536], [1600, 576],[576, 1600], [1664, 512],[512, 1664]] ASPECT_832 = [[832, 832], [896, 768], [768, 896], [960, 704], [704, 960], [1024, 640], [640, 1024], [1152, 576], [576, 1152], [1280, 512], [512, 1280], [1344, 512], [512, 1344], [1408, 448], [448, 1408], [1472, 448], [448, 1472], [1536, 384], [384, 1536], [1600, 384], [384, 1600]] ASPECT_896 = [[896, 896], [960, 832], [832, 960], [1024, 768], [768, 1024], [1088, 704], [704, 1088], [1152, 704], [704, 1152], [1216, 640], [640, 1216], [1280, 640], [640, 1280], [1344, 576], [576, 1344], [1408, 576], [576, 1408], [1472, 512], [512, 1472], [1536, 512], [512, 1536], [1600, 448], [448, 1600], [1664, 448], [448, 1664]] ASPECT_960 = [[960, 960], [1024, 896],[896, 1024], [1088, 832],[832, 1088], [1152, 768],[768, 1152], [1216, 704],[704, 1216], [1280, 640],[640, 1280], [1344, 576],[576, 1344], [1408, 512],[512, 1408], [1472, 448],[448, 1472], [1536, 384],[384, 1536]] ASPECT_1024 = [[1024, 1024], [1088, 960], [960, 1088], [1152, 896], [896, 1152], [1216, 832], [832, 1216], [1344, 768], [768, 1344], [1472, 704], [704, 1472], [1600, 640], [640, 1600], [1728, 576], [576, 1728], [1792, 576], [576, 1792]] ASPECT_768 = [[768,768], # 589824 1:1 [896,640],[640,896], # 573440 1.4:1 [832,704],[704,832], # 585728 1.181:1 [960,576],[576,960], # 552960 1.6:1 [1024,576],[576,1024], # 524288 1.778:1 [1088,512],[512,1088], # 497664 2.125:1 [1152,512],[512,1152], # 589824 2.25:1 [1216,448],[448,1216], # 552960 2.714:1 [1280,448],[448,1280], # 573440 2.857:1 [1344,384],[384,1344], # 518400 3.5:1 [1408,384],[384,1408], # 540672 3.667:1 [1472,320],[320,1472], # 470400 4.6:1 [1536,320],[320,1536], # 491520 4.8:1 ] ASPECT_704 = [[704,704], # 501,376 1:1 [768,640],[640,768], # 491,520 1.2:1 [832,576],[576,832], # 458,752 1.444:1 [896,512],[512,896], # 458,752 1.75:1 [960,512],[512,960], # 491,520 1.875:1 [1024,448],[448,1024], # 458,752 2.286:1 [1088,448],[448,1088], # 487,424 2.429:1 [1152,384],[384,1152], # 442,368 3:1 [1216,384],[384,1216], # 466,944 3.125:1 [1280,384],[384,1280], # 491,520 3.333:1 [1280,320],[320,1280], # 409,600 4:1 [1408,320],[320,1408], # 450,560 4.4:1 [1536,320],[320,1536], # 491,520 4.8:1 ] ASPECT_640 = [[640,640], # 409600 1:1 [704,576],[576,704], # 405504 1.25:1 [768,512],[512,768], # 393216 1.5:1 [896,448],[448,896], # 401408 2:1 [1024,384],[384,1024], # 393216 2.667:1 [1280,320],[320,1280], # 409600 4:1 [1408,256],[256,1408], # 360448 5.5:1 [1472,256],[256,1472], # 376832 5.75:1 [1536,256],[256,1536], # 393216 6:1 [1600,256],[256,1600], # 409600 6.25:1 ] ASPECT_576 = [[576,576], # 331776 1:1 [640,512],[512,640], # 327680 1.25:1 [640,448],[448,640], # 286720 1.4286:1 [704,448],[448,704], # 314928 1.5625:1 [832,384],[384,832], # 317440 2.1667:1 [1024,320],[320,1024], # 327680 3.2:1 [1280,256],[256,1280], # 327680 5:1 ] ASPECT_512 = [[512,512], # 262144 1:1 [576,448],[448,576], # 258048 1.29:1 [640,384],[384,640], # 245760 1.667:1 [768,320],[320,768], # 245760 2.4:1 [832,256],[256,832], # 212992 3.25:1 [896,256],[256,896], # 229376 3.5:1 [960,256],[256,960], # 245760 3.75:1 [1024,256],[256,1024], # 245760 4:1 ] ASPECT_448 = [[448,448], # 200704 1:1 [512,384],[384,512], # 196608 1.33:1 [576,320],[320,576], # 184320 1.8:1 [768,256],[256,768], # 196608 3:1 ] ASPECT_384 = [[384,384], # 147456 1:1 [448,320],[320,448], # 143360 1.4:1 [576,256],[256,576], # 147456 2.25:1 [768,192],[192,768], # 147456 4:1 ] ASPECT_320 = [[320,320], # 102400 1:1 [384,256],[256,384], # 98304 1.5:1 [512,192],[192,512], # 98304 2.67:1 ] ASPECT_256 = [[256,256], # 65536 1:1 [320,192],[192,320], # 61440 1.67:1 [512,128],[128,512], # 65536 4:1 ] #failsafe aspects ASPECTS = ASPECT_512 def get_aspect_buckets(resolution,mode=''): if resolution < 256: raise ValueError("Resolution must be at least 512") try: rounded_resolution = int(resolution / 64) * 64 print(f" {bcolors.WARNING} Rounded resolution to: {rounded_resolution}{bcolors.ENDC}") all_image_sizes = __get_all_aspects() if mode == 'MJ': #truncate to the first 3 resolutions all_image_sizes = [x[0:3] for x in all_image_sizes] aspects = next(filter(lambda sizes: sizes[0][0]==rounded_resolution, all_image_sizes), None) ASPECTS = aspects #print(aspects) return aspects except Exception as e: print(f" {bcolors.FAIL} *** Could not find selected resolution: {rounded_resolution}{bcolors.ENDC}") raise e def __get_all_aspects(): return [ASPECT_256, ASPECT_320, ASPECT_384, ASPECT_448, ASPECT_512, ASPECT_576, ASPECT_640, ASPECT_704, ASPECT_768,ASPECT_832,ASPECT_896,ASPECT_960,ASPECT_1024,ASPECT_1088,ASPECT_1152,ASPECT_1216,ASPECT_1280,ASPECT_1344,ASPECT_1408,ASPECT_1472,ASPECT_1536,ASPECT_1600,ASPECT_1664,ASPECT_1728,ASPECT_1792,ASPECT_1856,ASPECT_1920,ASPECT_1984,ASPECT_2048] class AutoBucketing(Dataset): def __init__(self, concepts_list, tokenizer=None, flip_p=0.0, repeats=1, debug_level=0, batch_size=1, set='val', resolution=512, center_crop=False, use_image_names_as_captions=True, shuffle_captions=False, add_class_images_to_dataset=None, balance_datasets=False, crop_jitter=20, with_prior_loss=False, use_text_files_as_captions=False, aspect_mode='dynamic', action_preference='dynamic', seed=555, model_variant='base', extra_module=None, mask_prompts=None, load_mask=False, ): self.debug_level = debug_level self.resolution = resolution self.center_crop = center_crop self.tokenizer = tokenizer self.batch_size = batch_size self.concepts_list = concepts_list self.use_image_names_as_captions = use_image_names_as_captions self.shuffle_captions = shuffle_captions self.num_train_images = 0 self.num_reg_images = 0 self.image_train_items = [] self.image_reg_items = [] self.add_class_images_to_dataset = add_class_images_to_dataset self.balance_datasets = balance_datasets self.crop_jitter = crop_jitter self.with_prior_loss = with_prior_loss self.use_text_files_as_captions = use_text_files_as_captions self.aspect_mode = aspect_mode self.action_preference = action_preference self.model_variant = model_variant self.extra_module = extra_module self.image_transforms = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) self.mask_transforms = transforms.Compose( [ transforms.ToTensor(), ] ) self.depth_image_transforms = transforms.Compose( [ transforms.ToTensor(), ] ) self.seed = seed #shared_dataloader = None print(f" {bcolors.WARNING}Creating Auto Bucketing Dataloader{bcolors.ENDC}") shared_dataloader = DataLoaderMultiAspect(concepts_list, debug_level=debug_level, resolution=self.resolution, seed=self.seed, batch_size=self.batch_size, flip_p=flip_p, use_image_names_as_captions=self.use_image_names_as_captions, add_class_images_to_dataset=self.add_class_images_to_dataset, balance_datasets=self.balance_datasets, with_prior_loss=self.with_prior_loss, use_text_files_as_captions=self.use_text_files_as_captions, aspect_mode=self.aspect_mode, action_preference=self.action_preference, model_variant=self.model_variant, extra_module=self.extra_module, mask_prompts=mask_prompts, load_mask=load_mask, ) #print(self.image_train_items) if self.with_prior_loss and self.add_class_images_to_dataset == False: self.image_train_items, self.class_train_items = shared_dataloader.get_all_images() self.num_train_images = self.num_train_images + len(self.image_train_items) self.num_reg_images = self.num_reg_images + len(self.class_train_items) self._length = max(max(math.trunc(self.num_train_images * repeats), batch_size),math.trunc(self.num_reg_images * repeats), batch_size) - self.num_train_images % self.batch_size self.num_train_images = self.num_train_images + self.num_reg_images else: self.image_train_items = shared_dataloader.get_all_images() self.num_train_images = self.num_train_images + len(self.image_train_items) self._length = max(math.trunc(self.num_train_images * repeats), batch_size) - self.num_train_images % self.batch_size print() print(f" {bcolors.WARNING} ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} {bcolors.ENDC}") print() def __len__(self): return self._length def __getitem__(self, i): idx = i % self.num_train_images #print(idx) image_train_item = self.image_train_items[idx] example = self.__get_image_for_trainer(image_train_item,debug_level=self.debug_level) if self.with_prior_loss and self.add_class_images_to_dataset == False: idx = i % self.num_reg_images class_train_item = self.class_train_items[idx] example_class = self.__get_image_for_trainer(class_train_item,debug_level=self.debug_level,class_img=True) example= {**example, **example_class} #print the tensor shape #print(example['instance_images'].shape) #print(example.keys()) return example def normalize8(self,I): mn = I.min() mx = I.max() mx -= mn I = ((I - mn)/mx) * 255 return I.astype(np.uint8) def __get_image_for_trainer(self,image_train_item,debug_level=0,class_img=False): example = {} save = debug_level > 2 if class_img==False: image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter) image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB") instance_prompt = image_train_tmp.caption if self.shuffle_captions: caption_parts = instance_prompt.split(",") random.shuffle(caption_parts) instance_prompt = ",".join(caption_parts) example["instance_images"] = self.image_transforms(image_train_tmp_image) if image_train_tmp.mask is not None: image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L") example["mask"] = self.mask_transforms(image_train_tmp_mask) if self.model_variant == 'depth2img': image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L") example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth) #print(instance_prompt) example["instance_prompt_ids"] = self.tokenizer( instance_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids image_train_item.self_destruct() return example if class_img==True: image_train_tmp = image_train_item.hydrate(crop=False, save=4, crop_jitter=self.crop_jitter) image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB") if self.model_variant == 'depth2img': image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L") example["class_depth_images"] = self.depth_image_transforms(image_train_tmp_depth) example["class_images"] = self.image_transforms(image_train_tmp_image) example["class_prompt_ids"] = self.tokenizer( image_train_tmp.caption, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids image_train_item.self_destruct() return example _RANDOM_TRIM = 0.04 class ImageTrainItem(): """ image: Image mask: Image extra: Image identifier: caption, target_aspect: (width, height), pathname: path to image file flip_p: probability of flipping image (0.0 to 1.0) """ def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False): self.caption = caption self.target_wh = target_wh self.pathname = pathname self.mask_pathname = os.path.splitext(pathname)[0] + "-masklabel.png" self.depth_pathname = os.path.splitext(pathname)[0] + "-depth.png" self.flip_p = flip_p self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.cropped_img = None self.model_variant = model_variant self.load_mask=load_mask self.is_dupe = [] self.variant_warning = False self.image = image self.mask = mask self.extra = extra def self_destruct(self): self.image = None self.mask = None self.extra = None self.cropped_img = None self.is_dupe.append(1) def load_image(self, pathname, crop, jitter_amount, flip): if len(self.is_dupe) > 0: self.flip = transforms.RandomHorizontalFlip(p=1.0 if flip else 0.0) image = Image.open(pathname).convert('RGB') width, height = image.size if crop: cropped_img = self.__autocrop(image) image = cropped_img.resize((512, 512), resample=Image.Resampling.LANCZOS) else: width, height = image.size if self.target_wh[0] == self.target_wh[1]: if width > height: left = random.randint(0, width - height) image = image.crop((left, 0, height + left, height)) width = height elif height > width: top = random.randint(0, height - width) image = image.crop((0, top, width, width + top)) height = width elif width > self.target_wh[0]: slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0]) slicew_ratio = random.random() left = int(slice * slicew_ratio) right = width - int(slice * (1 - slicew_ratio)) sliceh_ratio = random.random() top = int(slice * sliceh_ratio) bottom = height - int(slice * (1 - sliceh_ratio)) image = image.crop((left, top, right, bottom)) else: image_aspect = width / height target_aspect = self.target_wh[0] / self.target_wh[1] if image_aspect > target_aspect: new_width = int(height * target_aspect) jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0) left = jitter_amount right = left + new_width image = image.crop((left, 0, right, height)) else: new_height = int(width / target_aspect) jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0) top = jitter_amount bottom = top + new_height image = image.crop((0, top, width, bottom)) # LAZCOS resample image = image.resize(self.target_wh, resample=Image.Resampling.LANCZOS) # print the pixel count of the image # print path to image file # print(self.pathname) # print(self.image.size[0] * self.image.size[1]) image = self.flip(image) return image def hydrate(self, crop=False, save=False, crop_jitter=20): """ crop: hard center crop to 512x512 save: save the cropped image to disk, for manual inspection of resize/crop crop_jitter: randomly shift cropp by N pixels when using multiple aspect ratios to improve training quality """ if self.image is None: chance = float(len(self.is_dupe)) / 10.0 flip_p = self.flip_p + chance if chance < 1.0 else 1.0 flip = random.uniform(0, 1) < flip_p if len(self.is_dupe) > 0: crop_jitter = crop_jitter + (len(self.is_dupe) * 10) if crop_jitter < 50 else 50 jitter_amount = random.randint(0, crop_jitter) self.image = self.load_image(self.pathname, crop, jitter_amount, flip) if self.model_variant == "inpainting" or self.load_mask: if os.path.exists(self.mask_pathname) and self.load_mask: self.mask = self.load_image(self.mask_pathname, crop, jitter_amount, flip) else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True self.mask = Image.new('RGB', self.image.size, color="white").convert("L") if self.model_variant == "depth2img": if os.path.exists(self.depth_pathname): self.extra = self.load_image(self.depth_pathname, crop, jitter_amount, flip) else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No depth found for an image, using an empty depth but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True self.extra = Image.new('RGB', self.image.size, color="white").convert("L") if type(self.image) is not np.ndarray: if save: base_name = os.path.basename(self.pathname) if not os.path.exists("test/output"): os.makedirs("test/output") self.image.save(f"test/output/{base_name}") self.image = np.array(self.image).astype(np.uint8) self.image = (self.image / 127.5 - 1.0).astype(np.float32) if self.mask is not None and type(self.mask) is not np.ndarray: self.mask = np.array(self.mask).astype(np.uint8) self.mask = (self.mask / 255.0).astype(np.float32) if self.extra is not None and type(self.extra) is not np.ndarray: self.extra = np.array(self.extra).astype(np.uint8) self.extra = (self.extra / 255.0).astype(np.float32) #print(self.image.shape) return self class CachedLatentsDataset(Dataset): #stores paths and loads latents on the fly def __init__(self, cache_paths=(),batch_size=None,tokenizer=None,text_encoder=None,dtype=None,model_variant='base',shuffle_per_epoch=False,args=None): self.cache_paths = cache_paths self.tokenizer = tokenizer self.args = args self.text_encoder = text_encoder #get text encoder device text_encoder_device = next(self.text_encoder.parameters()).device self.empty_batch = [self.tokenizer('',padding="do_not_pad",truncation=True,max_length=self.tokenizer.model_max_length,).input_ids for i in range(batch_size)] #handle text encoder for empty tokens if self.args.train_text_encoder != True: self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch},padding="max_length",max_length=tokenizer.model_max_length,return_tensors="pt",).to(text_encoder_device).input_ids self.empty_tokens.to(text_encoder_device, dtype=dtype) self.empty_tokens = self.text_encoder(self.empty_tokens)[0] else: self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch},padding="max_length",max_length=tokenizer.model_max_length,return_tensors="pt",).input_ids self.empty_tokens.to(text_encoder_device, dtype=dtype) self.conditional_dropout = args.conditional_dropout self.conditional_indexes = [] self.model_variant = model_variant self.shuffle_per_epoch = shuffle_per_epoch def __len__(self): return len(self.cache_paths) def __getitem__(self, index): if index == 0: if self.shuffle_per_epoch == True: self.cache_paths = tuple(random.sample(self.cache_paths, len(self.cache_paths))) if len(self.cache_paths) > 1: possible_indexes_extension = None possible_indexes = list(range(0,len(self.cache_paths))) #conditional dropout is a percentage of images to drop from the total cache_paths if self.conditional_dropout != None: if len(self.conditional_indexes) == 0: self.conditional_indexes = random.sample(possible_indexes, k=int(math.ceil(len(possible_indexes)*self.conditional_dropout))) else: #pick indexes from the remaining possible indexes possible_indexes_extension = [i for i in possible_indexes if i not in self.conditional_indexes] #duplicate all values in possible_indexes_extension possible_indexes_extension = possible_indexes_extension + possible_indexes_extension possible_indexes_extension = possible_indexes_extension + self.conditional_indexes self.conditional_indexes = random.sample(possible_indexes_extension, k=int(math.ceil(len(possible_indexes)*self.conditional_dropout))) #check for duplicates in conditional_indexes values if len(self.conditional_indexes) != len(set(self.conditional_indexes)): #remove duplicates self.conditional_indexes_non_dupe = list(set(self.conditional_indexes)) #add a random value from possible_indexes_extension for each duplicate for i in range(len(self.conditional_indexes) - len(self.conditional_indexes_non_dupe)): while True: random_value = random.choice(possible_indexes_extension) if random_value not in self.conditional_indexes_non_dupe: self.conditional_indexes_non_dupe.append(random_value) break self.conditional_indexes = self.conditional_indexes_non_dupe self.cache = torch.load(self.cache_paths[index]) self.latents = self.cache.latents_cache[0] self.tokens = self.cache.tokens_cache[0] self.extra_cache = None self.mask_cache = None if self.cache.mask_cache is not None: self.mask_cache = self.cache.mask_cache[0] self.mask_mean_cache = None if self.cache.mask_mean_cache is not None: self.mask_mean_cache = self.cache.mask_mean_cache[0] if index in self.conditional_indexes: self.text_encoder = self.empty_tokens else: self.text_encoder = self.cache.text_encoder_cache[0] if self.model_variant != 'base': self.extra_cache = self.cache.extra_cache[0] del self.cache return self.latents, self.text_encoder, self.mask_cache, self.mask_mean_cache, self.extra_cache, self.tokens def add_pt_cache(self, cache_path): if len(self.cache_paths) == 0: self.cache_paths = (cache_path,) else: self.cache_paths += (cache_path,) class LatentsDataset(Dataset): def __init__(self, latents_cache=None, text_encoder_cache=None, mask_cache=None, mask_mean_cache=None, extra_cache=None,tokens_cache=None): self.latents_cache = latents_cache self.text_encoder_cache = text_encoder_cache self.mask_cache = mask_cache self.mask_mean_cache = mask_mean_cache self.extra_cache = extra_cache self.tokens_cache = tokens_cache def add_latent(self, latent, text_encoder, cached_mask, cached_extra, tokens_cache): self.latents_cache.append(latent) self.text_encoder_cache.append(text_encoder) self.mask_cache.append(cached_mask) self.mask_mean_cache.append(None if cached_mask is None else cached_mask.mean()) self.extra_cache.append(cached_extra) self.tokens_cache.append(tokens_cache) def __len__(self): return len(self.latents_cache) def __getitem__(self, index): return self.latents_cache[index], self.text_encoder_cache[index], self.mask_cache[index], self.mask_mean_cache[index], self.extra_cache[index], self.tokens_cache[index] class DataLoaderMultiAspect(): """ Data loader for multi-aspect-ratio training and bucketing data_root: root folder of training data batch_size: number of images per batch flip_p: probability of flipping image horizontally (i.e. 0-0.5) """ def __init__( self, concept_list, seed=555, debug_level=0, resolution=512, batch_size=1, flip_p=0.0, use_image_names_as_captions=True, add_class_images_to_dataset=False, balance_datasets=False, with_prior_loss=False, use_text_files_as_captions=False, aspect_mode='dynamic', action_preference='add', model_variant='base', extra_module=None, mask_prompts=None, load_mask=False, ): self.resolution = resolution self.debug_level = debug_level self.flip_p = flip_p self.use_image_names_as_captions = use_image_names_as_captions self.balance_datasets = balance_datasets self.with_prior_loss = with_prior_loss self.add_class_images_to_dataset = add_class_images_to_dataset self.use_text_files_as_captions = use_text_files_as_captions self.aspect_mode = aspect_mode self.action_preference = action_preference self.seed = seed self.model_variant = model_variant self.extra_module = extra_module self.load_mask = load_mask prepared_train_data = [] self.aspects = get_aspect_buckets(resolution) #print(f"* DLMA resolution {resolution}, buckets: {self.aspects}") #process sub directories flag print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}") if balance_datasets: print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}") #get the concept with the least number of images in instance_data_dir for concept in concept_list: count = 0 if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == 1: tot = 0 for root, dirs, files in os.walk(concept['instance_data_dir']): tot += len(files) count = tot else: count = len(os.listdir(concept['instance_data_dir'])) else: count = len(os.listdir(concept['instance_data_dir'])) print(f"{concept['instance_data_dir']} has count of {count}") concept['count'] = count min_concept = min(concept_list, key=lambda x: x['count']) #get the number of images in the concept with the least number of images min_concept_num_images = min_concept['count'] print(" Min concept: ",min_concept['instance_data_dir']," with ",min_concept_num_images," images") balance_cocnept_list = [] for concept in concept_list: #if concept has a key do not balance it if 'do_not_balance' in concept: if concept['do_not_balance'] == True: balance_cocnept_list.append(-1) else: balance_cocnept_list.append(min_concept_num_images) else: balance_cocnept_list.append(min_concept_num_images) for concept in concept_list: if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: use_sub_dirs = True else: use_sub_dirs = False else: use_sub_dirs = False self.image_paths = [] #self.class_image_paths = [] min_concept_num_images = None if balance_datasets: min_concept_num_images = balance_cocnept_list[concept_list.index(concept)] data_root = concept['instance_data_dir'] data_root_class = concept['class_data_dir'] concept_prompt = concept['instance_prompt'] concept_class_prompt = concept['class_prompt'] if 'flip_p' in concept.keys(): flip_p = concept['flip_p'] if flip_p == '': flip_p = 0.0 else: flip_p = float(flip_p) self.__recurse_data_root(self=self, recurse_root=data_root,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] if add_class_images_to_dataset: self.image_paths = [] self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) use_image_names_as_captions = False prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if self.with_prior_loss and add_class_images_to_dataset == False: self.class_image_caption_pairs = [] for concept in concept_list: self.class_images_path = [] data_root_class = concept['class_data_dir'] concept_class_prompt = concept['class_prompt'] self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs,class_images=True) random.Random(seed).shuffle(self.image_paths) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) use_image_names_as_captions = False self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}") clip_seg = ClipSeg() clip_seg.mask_images(self.image_paths, mask_prompts) del clip_seg if debug_level > 0: print(f" * DLMA Example: {self.image_caption_pairs[0]} images") #print the length of image_caption_pairs print(f" {bcolors.WARNING} Number of image-caption pairs: {len(self.image_caption_pairs)}{bcolors.ENDC}") if len(self.image_caption_pairs) == 0: raise Exception("All the buckets are empty. Please check your data or reduce the batch size.") def get_all_images(self): if self.with_prior_loss == False: return self.image_caption_pairs else: return self.image_caption_pairs, self.class_image_caption_pairs def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,use_text_files_as_captions=False): """ Create ImageTrainItem objects with metadata for hydration later """ decorated_image_train_items = [] for pathname in image_paths: identifier = concept if use_image_names_as_captions: caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] identifier = caption_from_filename if use_text_files_as_captions: txt_file_path = os.path.splitext(pathname)[0] + ".txt" if os.path.exists(txt_file_path): try: with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f: identifier = f.readline().rstrip() f.close() if len(identifier) < 1: raise ValueError(f" *** Could not find valid text in: {txt_file_path}") except Exception as e: print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}") print(e) identifier = caption_from_filename pass #print("identifier: ",identifier) image = Image.open(pathname) width, height = image.size image_aspect = width / height target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask) decorated_image_train_items.append(image_train_item) return decorated_image_train_items @staticmethod def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0,aspect_mode='dynamic',action_preference='add'): """ Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder """ # TODO: this is not terribly efficient but at least linear time buckets = {} for image_caption_pair in prepared_train_data: target_wh = image_caption_pair.target_wh if (target_wh[0],target_wh[1]) not in buckets: buckets[(target_wh[0],target_wh[1])] = [] buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) print(f" ** Number of buckets: {len(buckets)}") for bucket in buckets: bucket_len = len(buckets[bucket]) #real_len = len(buckets[bucket])+1 #print(real_len) truncate_amount = bucket_len % batch_size add_amount = batch_size - bucket_len % batch_size action = None #print(f" ** Bucket {bucket} has {bucket_len} images") if aspect_mode == 'dynamic': if batch_size == bucket_len: action = None elif add_amount < truncate_amount and add_amount != 0 and add_amount != batch_size or truncate_amount == 0: action = 'add' #print(f'should add {add_amount}') elif truncate_amount < add_amount and truncate_amount != 0 and truncate_amount != batch_size and batch_size < bucket_len: #print(f'should truncate {truncate_amount}') action = 'truncate' #truncate the bucket elif truncate_amount == add_amount: if action_preference == 'add': action = 'add' elif action_preference == 'truncate': action = 'truncate' elif batch_size > bucket_len: action = 'add' elif aspect_mode == 'add': action = 'add' elif aspect_mode == 'truncate': action = 'truncate' if action == None: action = None #print('no need to add or truncate') if action == None: #print('test') current_bucket_size = bucket_len print(f" ** Bucket {bucket} found {bucket_len}, nice!") elif action == 'add': #copy the bucket shuffleBucket = random.sample(buckets[bucket], bucket_len) #add the images to the bucket current_bucket_size = bucket_len truncate_count = (bucket_len) % batch_size #how many images to add to the bucket to fill the batch addAmount = batch_size - truncate_count if addAmount != batch_size: added=0 while added != addAmount: randomIndex = random.randint(0,len(shuffleBucket)-1) #print(str(randomIndex)) buckets[bucket].append(shuffleBucket[randomIndex]) added+=1 print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.OKCYAN}duplicate {added} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}") else: print(f" ** Bucket {bucket} found {bucket_len}, {bcolors.OKGREEN}nice!{bcolors.ENDC}") elif action == 'truncate': truncate_count = (bucket_len) % batch_size current_bucket_size = bucket_len buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.FAIL}drop {truncate_count} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}") # flatten the buckets image_caption_pairs = [] for bucket in buckets: image_caption_pairs.extend(buckets[bucket]) return image_caption_pairs @staticmethod def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False): progress_bar = tqdm(os.listdir(recurse_root), desc=f" {bcolors.WARNING} ** Processing {recurse_root}{bcolors.ENDC}") for f in os.listdir(recurse_root): current = os.path.join(recurse_root, f) if os.path.isfile(current): ext = os.path.splitext(f)[1].lower() if '-depth' in f or '-masklabel' in f: progress_bar.update(1) continue if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']: #try to open the file to make sure it's a valid image try: img = Image.open(current) except: print(f" ** Skipping {current} because it failed to open, please check the file") progress_bar.update(1) continue del img if class_images == False: self.image_paths.append(current) else: self.class_images_path.append(current) progress_bar.update(1) if use_sub_dirs: sub_dirs = [] for d in os.listdir(recurse_root): current = os.path.join(recurse_root, d) if os.path.isdir(current): sub_dirs.append(current) for dir in sub_dirs: self.__recurse_data_root(self=self, recurse_root=dir) class NormalDataset(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. It pre-processes the images and the tokenizes prompts. """ def __init__( self, concepts_list, tokenizer, with_prior_preservation=True, size=512, center_crop=False, num_class_images=None, use_image_names_as_captions=False, shuffle_captions=False, repeats=1, use_text_files_as_captions=False, seed=555, model_variant='base', extra_module=None, mask_prompts=None, load_mask=None, ): self.use_image_names_as_captions = use_image_names_as_captions self.shuffle_captions = shuffle_captions self.size = size self.center_crop = center_crop self.tokenizer = tokenizer self.with_prior_preservation = with_prior_preservation self.use_text_files_as_captions = use_text_files_as_captions self.image_paths = [] self.class_images_path = [] self.seed = seed self.model_variant = model_variant self.variant_warning = False self.vae_scale_factor = None self.load_mask = load_mask for concept in concepts_list: if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: use_sub_dirs = True else: use_sub_dirs = False else: use_sub_dirs = False for i in range(repeats): self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs) if with_prior_preservation: for i in range(repeats): self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}") clip_seg = ClipSeg() clip_seg.mask_images(self.image_paths, mask_prompts) del clip_seg random.Random(seed).shuffle(self.image_paths) self.num_instance_images = len(self.image_paths) self._length = self.num_instance_images self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}") self.vae_scale_factor = extra_module.depth_images(self.image_paths) if self.with_prior_preservation: print(f" {bcolors.WARNING} ** Loading Depth2Img Class Processing{bcolors.ENDC}") extra_module.depth_images(self.class_images_path) print(f" {bcolors.WARNING} ** Dataset length: {self._length}, {int(self.num_instance_images / repeats)} images using {repeats} repeats{bcolors.ENDC}") self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) self.mask_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), ]) self.depth_image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), ] ) @staticmethod def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False): #if recurse root is a dict if isinstance(recurse_root, dict): if class_images == True: #print(f" {bcolors.WARNING} ** Processing class images: {recurse_root['class_data_dir']}{bcolors.ENDC}") concept_token = recurse_root['class_prompt'] data = recurse_root['class_data_dir'] else: #print(f" {bcolors.WARNING} ** Processing instance images: {recurse_root['instance_data_dir']}{bcolors.ENDC}") concept_token = recurse_root['instance_prompt'] data = recurse_root['instance_data_dir'] else: concept_token = None #progress bar progress_bar = tqdm(os.listdir(data), desc=f" {bcolors.WARNING} ** Processing {data}{bcolors.ENDC}") for f in os.listdir(data): current = os.path.join(data, f) if os.path.isfile(current): if '-depth' in f or '-masklabel' in f: continue ext = os.path.splitext(f)[1].lower() if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']: try: img = Image.open(current) except: print(f" ** Skipping {current} because it failed to open, please check the file") progress_bar.update(1) continue del img if class_images == False: self.image_paths.append([current,concept_token]) else: self.class_images_path.append([current,concept_token]) progress_bar.update(1) if use_sub_dirs: sub_dirs = [] for d in os.listdir(data): current = os.path.join(data, d) if os.path.isdir(current): sub_dirs.append(current) for dir in sub_dirs: if class_images == False: self.__recurse_data_root(self=self, recurse_root={'instance_data_dir' : dir, 'instance_prompt' : concept_token}) else: self.__recurse_data_root(self=self, recurse_root={'class_data_dir' : dir, 'class_prompt' : concept_token}) def __len__(self): return self._length def __getitem__(self, index): example = {} instance_path, instance_prompt = self.image_paths[index % self.num_instance_images] og_prompt = instance_prompt instance_image = Image.open(instance_path) if self.model_variant == "inpainting" or self.load_mask: mask_pathname = os.path.splitext(instance_path)[0] + "-masklabel.png" if os.path.exists(mask_pathname) and self.load_mask: mask = Image.open(mask_pathname).convert("L") else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True size = instance_image.size mask = Image.new('RGB', size, color="white").convert("L") example["mask"] = self.mask_transforms(mask) if self.model_variant == "depth2img": depth_pathname = os.path.splitext(instance_path)[0] + "-depth.png" if os.path.exists(depth_pathname): depth_image = Image.open(depth_pathname).convert("L") else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True size = instance_image.size depth_image = Image.new('RGB', size, color="white").convert("L") example["instance_depth_images"] = self.depth_image_transforms(depth_image) if self.use_image_names_as_captions == True: instance_prompt = str(instance_path).split(os.sep)[-1].split('.')[0].split('_')[0] #else if there's a txt file with the same name as the image, read the caption from there if self.use_text_files_as_captions == True: #if there's a file with the same name as the image, but with a .txt extension, read the caption from there #get the last . in the file name last_dot = str(instance_path).rfind('.') #get the path up to the last dot txt_path = str(instance_path)[:last_dot] + '.txt' #if txt_path exists, read the caption from there if os.path.exists(txt_path): with open(txt_path, encoding='utf-8') as f: instance_prompt = f.readline().rstrip() f.close() if self.shuffle_captions: caption_parts = instance_prompt.split(",") random.shuffle(caption_parts) instance_prompt = ",".join(caption_parts) #print('identifier: ' + instance_prompt) instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) example["instance_prompt_ids"] = self.tokenizer( instance_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids if self.with_prior_preservation: class_path, class_prompt = self.class_images_path[index % self.num_class_images] class_image = Image.open(class_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") if self.model_variant == "inpainting": mask_pathname = os.path.splitext(class_path)[0] + "-masklabel.png" if os.path.exists(mask_pathname): mask = Image.open(mask_pathname).convert("L") else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True size = instance_image.size mask = Image.new('RGB', size, color="white").convert("L") example["class_mask"] = self.mask_transforms(mask) if self.model_variant == "depth2img": depth_pathname = os.path.splitext(class_path)[0] + "-depth.png" if os.path.exists(depth_pathname): depth_image = Image.open(depth_pathname) else: if self.variant_warning == False: print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}") self.variant_warning = True size = instance_image.size depth_image = Image.new('RGB', size, color="white").convert("L") example["class_depth_images"] = self.depth_image_transforms(depth_image) example["class_images"] = self.image_transforms(class_image) example["class_prompt_ids"] = self.tokenizer( class_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids return example class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." def __init__(self, prompt, num_samples): self.prompt = prompt self.num_samples = num_samples def __len__(self): return self.num_samples def __getitem__(self, index): example = {} example["prompt"] = self.prompt example["index"] = index return example