Save test images with train_dreambooth_lora.py
Browse files- train_dreambooth_lora.py +7 -1
train_dreambooth_lora.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding=utf-8
|
3 |
#
|
4 |
-
# This file is
|
5 |
# The original license is as below:
|
6 |
#
|
7 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
@@ -981,6 +981,12 @@ def main(args):
|
|
981 |
prompt = args.num_validation_images * [args.validation_prompt]
|
982 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
983 |
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
for tracker in accelerator.trackers:
|
985 |
if tracker.name == "tensorboard":
|
986 |
np_images = np.stack([np.asarray(img) for img in images])
|
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding=utf-8
|
3 |
#
|
4 |
+
# This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
|
5 |
# The original license is as below:
|
6 |
#
|
7 |
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
|
981 |
prompt = args.num_validation_images * [args.validation_prompt]
|
982 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
983 |
|
984 |
+
test_image_dir = Path(args.output_dir) / 'test_images'
|
985 |
+
test_image_dir.mkdir()
|
986 |
+
for i, image in enumerate(images):
|
987 |
+
out_path = test_image_dir / f'image_{i}.png'
|
988 |
+
image.save(out_path)
|
989 |
+
|
990 |
for tracker in accelerator.trackers:
|
991 |
if tracker.name == "tensorboard":
|
992 |
np_images = np.stack([np.asarray(img) for img in images])
|