hysts HF staff commited on
Commit
2873c82
1 Parent(s): 8b149b2

Save test images with train_dreambooth_lora.py

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora.py +7 -1
train_dreambooth_lora.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
  #
4
- # This file is copied from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
5
  # The original license is as below:
6
  #
7
  # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
@@ -981,6 +981,12 @@ def main(args):
981
  prompt = args.num_validation_images * [args.validation_prompt]
982
  images = pipeline(prompt, num_inference_steps=25, generator=generator).images
983
 
 
 
 
 
 
 
984
  for tracker in accelerator.trackers:
985
  if tracker.name == "tensorboard":
986
  np_images = np.stack([np.asarray(img) for img in images])
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
  #
4
+ # This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
5
  # The original license is as below:
6
  #
7
  # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
 
981
  prompt = args.num_validation_images * [args.validation_prompt]
982
  images = pipeline(prompt, num_inference_steps=25, generator=generator).images
983
 
984
+ test_image_dir = Path(args.output_dir) / 'test_images'
985
+ test_image_dir.mkdir()
986
+ for i, image in enumerate(images):
987
+ out_path = test_image_dir / f'image_{i}.png'
988
+ image.save(out_path)
989
+
990
  for tracker in accelerator.trackers:
991
  if tracker.name == "tensorboard":
992
  np_images = np.stack([np.asarray(img) for img in images])