hysts (HF staff) committed
Commit 8b149b2
Parent: 307f13f

Use Uploader to upload models at training time


Using two different upload methods was not a good idea, so stop using the
upload path built into train_dreambooth_lora.py and use the Uploader class
in this repo instead.

Also, to make it easier to port future updates to train_dreambooth_lora.py
from the diffusers library, reset the local changes to that script.
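
In other words, train_dreambooth_lora.py is reset to match upstream diffusers, while trainer.py now writes the Space's own model card via utils.save_model_card and delegates the upload to the LoRAModelUploader from app_upload. That class's interface is not part of this diff, so the sketch below only approximates the upload step with plain huggingface_hub calls; the repo id, folder path, and token are made up for illustration.

from huggingface_hub import HfApi

# Rough approximation of what the upload step amounts to. The Space actually
# uses app_upload.LoRAModelUploader, whose interface is not shown in this commit.
api = HfApi(token='hf_...')  # placeholder token

repo_id = 'my-username/sks-dog-lora'  # made-up target repo
api.create_repo(repo_id=repo_id, repo_type='model', exist_ok=True, private=True)
api.upload_folder(repo_id=repo_id,
                  folder_path='experiments/sks-dog-lora',  # made-up output folder
                  repo_type='model')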

Files changed (3)
  1. train_dreambooth_lora.py +39 -44
  2. trainer.py +7 -0
  3. utils.py +38 -0
train_dreambooth_lora.py CHANGED
@@ -1,8 +1,9 @@
 #!/usr/bin/env python
-# This file is adapted from https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/examples/dreambooth/train_dreambooth_lora.py
-# The original license is as below.
-#
 # coding=utf-8
+#
+# This file is copied from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
+# The original license is as below:
+#
 # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +26,7 @@ import warnings
 from pathlib import Path
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -48,7 +50,7 @@ from diffusers.models.cross_attention import LoRACrossAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, create_repo, delete_repo, whoami
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -61,9 +63,9 @@ check_min_version("0.12.0.dev0")
 logger = get_logger(__name__)
 
 
-def save_model_card(repo_name, base_model, instance_prompt, test_prompt="", images=None, repo_folder=""):
-    img_str = f"Test prompt: {test_prompt}\n" if test_prompt else ""
-    for i, image in enumerate(images or []):
+def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
         img_str += f"![img_{i}](./image_{i}.png)\n"
 
@@ -71,7 +73,6 @@ def save_model_card(repo_name, base_model, instance_prompt, test_prompt="", images=None, repo_folder=""):
 ---
 license: creativeml-openrail-m
 base_model: {base_model}
-instance_prompt: {instance_prompt}
 tags:
 - stable-diffusion
 - stable-diffusion-diffusers
@@ -79,11 +80,11 @@ tags:
 - diffusers
 inference: true
 ---
-"""
+    """
     model_card = f"""
 # LoRA DreamBooth - {repo_name}
 
-These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.\n
+These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
 {img_str}
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -364,9 +365,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument("--private_repo", action="store_true")
-    parser.add_argument("--delete_existing_repo", action="store_true")
-    parser.add_argument("--upload_to_lora_library", action="store_true")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -610,17 +608,11 @@ def main(args):
     if accelerator.is_main_process:
         if args.push_to_hub:
            if args.hub_model_id is None:
-                organization = 'lora-library' if args.upload_to_lora_library else None
-                repo_name = get_full_repo_name(Path(args.output_dir).name, organization=organization, token=args.hub_token)
+                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
             else:
                 repo_name = args.hub_model_id
 
-            if args.delete_existing_repo:
-                try:
-                    delete_repo(repo_name, token=args.hub_token)
-                except Exception:
-                    pass
-            create_repo(repo_name, token=args.hub_token, private=args.private_repo)
+            create_repo(repo_name, exist_ok=True, token=args.hub_token)
             repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
 
             with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
@@ -826,14 +818,21 @@ def main(args):
             dirs = os.listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-            path = dirs[-1]
-            accelerator.print(f"Resuming from checkpoint {path}")
-            accelerator.load_state(os.path.join(args.output_dir, path))
-            global_step = int(path.split("-")[1])
-
-            resume_global_step = global_step * args.gradient_accumulation_steps
-            first_epoch = resume_global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % num_update_steps_per_epoch
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            resume_global_step = global_step * args.gradient_accumulation_steps
+            first_epoch = global_step // num_update_steps_per_epoch
+            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
@@ -943,6 +942,9 @@ def main(args):
                 images = pipeline(prompt, num_inference_steps=25, generator=generator).images
 
                 for tracker in accelerator.trackers:
+                    if tracker.name == "tensorboard":
+                        np_images = np.stack([np.asarray(img) for img in images])
+                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                     if tracker.name == "wandb":
                         tracker.log(
                             {
@@ -974,11 +976,15 @@ def main(args):
         pipeline.unet.load_attn_procs(args.output_dir)
 
         # run inference
-        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-        prompt = args.num_validation_images * [args.validation_prompt]
-        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+        if args.validation_prompt and args.num_validation_images > 0:
+            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+            prompt = args.num_validation_images * [args.validation_prompt]
+            images = pipeline(prompt, num_inference_steps=25, generator=generator).images
 
         for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                np_images = np.stack([np.asarray(img) for img in images])
+                tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
             if tracker.name == "wandb":
                 tracker.log(
                     {
@@ -992,23 +998,12 @@ def main(args):
         if args.push_to_hub:
             save_model_card(
                 repo_name,
-                base_model=args.pretrained_model_name_or_path,
-                instance_prompt=args.instance_prompt,
-                test_prompt=args.validation_prompt,
                 images=images,
-                repo_folder=args.output_dir,
-            )
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-        else:
-            repo_name = Path(args.output_dir).name
-            save_model_card(
-                repo_name,
                 base_model=args.pretrained_model_name_or_path,
-                instance_prompt=args.instance_prompt,
-                test_prompt=args.validation_prompt,
-                images=images,
+                prompt=args.instance_prompt,
                 repo_folder=args.output_dir,
             )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
     accelerator.end_training()
 
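A side note on the checkpoint-resume hunk above: besides handling the case where no checkpoint folder exists, the upstream change alters how the resume position is derived from the checkpoint's global step. A small worked example of that arithmetic, with made-up numbers:

# Made-up numbers to illustrate the resume arithmetic in the hunk above.
gradient_accumulation_steps = 4
num_update_steps_per_epoch = 100  # optimizer updates per epoch

global_step = 250  # parsed from a folder named "checkpoint-250"
resume_global_step = global_step * gradient_accumulation_steps  # 1000 batches consumed so far
first_epoch = global_step // num_update_steps_per_epoch  # resume inside epoch 2
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
# 1000 % 400 = 200 batches to skip within that epoch before training continues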
trainer.py CHANGED
@@ -14,6 +14,7 @@ import torch
 from huggingface_hub import HfApi
 
 from app_upload import LoRAModelUploader
+from utils import save_model_card
 
 
 def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -125,6 +126,12 @@ class Trainer:
             command_s = ' '.join(command.split())
             f.write(command_s)
         subprocess.run(shlex.split(command))
+        save_model_card(save_dir=output_dir,
+                        base_model=base_model,
+                        instance_prompt=instance_prompt,
+                        test_prompt=validation_prompt,
+                        test_image_dir='test_images')
+
         message = 'Training completed!'
         print(message)
 
utils.py CHANGED
@@ -18,3 +18,41 @@ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
             exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
         ]
     return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
+
+
+def save_model_card(
+    save_dir: pathlib.Path,
+    base_model: str,
+    instance_prompt: str,
+    test_prompt: str = '',
+    test_image_dir: str = '',
+) -> None:
+    image_str = ''
+    if test_prompt and test_image_dir:
+        image_paths = sorted((save_dir / test_image_dir).glob('*'))
+        if image_paths:
+            image_str = f'Test prompt: {test_prompt}\n'
+            for image_path in image_paths:
+                rel_path = image_path.relative_to(save_dir)
+                image_str += f'![{image_path.stem}]({rel_path})\n'
+
+    model_card = f'''---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+# LoRA DreamBooth - {save_dir.name}
+
+These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
+
+{image_str}
+'''
+
+    with open(save_dir / 'README.md', 'w') as f:
+        f.write(model_card)
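
For reference, a minimal sketch of how the new utils.save_model_card helper is meant to be called, mirroring the trainer.py call above; the directory and prompt values are made up, and utils.py is assumed to already import pathlib for its existing helpers.

import pathlib

from utils import save_model_card

# Hypothetical experiment folder from a finished run; test images are expected
# under <save_dir>/test_images, which is what trainer.py passes as test_image_dir.
save_dir = pathlib.Path('experiments/sks-dog-lora')
(save_dir / 'test_images').mkdir(parents=True, exist_ok=True)

save_model_card(save_dir=save_dir,
                base_model='runwayml/stable-diffusion-v1-5',
                instance_prompt='a photo of sks dog',
                test_prompt='a photo of sks dog swimming',
                test_image_dir='test_images')

# README.md now sits next to the weights, with the YAML front matter, the
# instance prompt, and links to any images found under test_images/.
print((save_dir / 'README.md').read_text())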