sayakpaul HF staff committed on
Commit
32c012b
·
verified ·
1 Parent(s): d9bab68

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -0
README.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ license: other
4
+ license_name: flux-1-dev-non-commercial-license
5
+ license_link: LICENSE.md
6
+ ---
7
+
8
+ > [!NOTE]
9
+ > Contains the NF4 checkpoints (`transformer` and `text_encoder_2`) of [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev). Please adhere to the original model licensing!
10
+
11
+ <details>
12
+ <summary>Code</summary>
13
+
14
+ ```py
15
+ from diffusers import DiffusionPipeline, FluxFillPipeline, FluxTransformer2DModel
16
+ import torch
17
+ from transformers import T5EncoderModel
18
+ from diffusers.utils import load_image
19
+ import fire
20
+
21
+
22
+ def load_pipeline(four_bit=False):
23
+ orig_pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
24
+ if four_bit:
25
+ print("Using four bit.")
26
+ transformer = FluxTransformer2DModel.from_pretrained(
27
+ "sayakpaul/FLUX.1-Fill-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
28
+ )
29
+ text_encoder_2 = T5EncoderModel.from_pretrained(
30
+ "sayakpaul/FLUX.1-Fill-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
31
+ )
32
+ pipeline = FluxFillPipeline.from_pipe(
33
+ orig_pipeline, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16
34
+ )
35
+ else:
36
+ transformer = FluxTransformer2DModel.from_pretrained(
37
+ "black-forest-labs/FLUX.1-Fill-dev",
38
+ subfolder="transformer",
39
+ revision="refs/pr/4",
40
+ torch_dtype=torch.bfloat16,
41
+ )
42
+ pipeline = FluxFillPipeline.from_pipe(orig_pipeline, transformer=transformer, torch_dtype=torch.bfloat16)
43
+
44
+ pipeline.enable_model_cpu_offload()
45
+ return pipeline
46
+
47
+
48
+ def load_conditions():
49
+ image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
50
+ mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
51
+ return image, mask
52
+
53
+
54
+ def main(ckpt_id: str, four_bit: bool = False):
55
+ pipe = load_pipeline(four_bit=four_bit)
56
+ image, mask = load_conditions()
57
+ image = pipe(
58
+ prompt="a white paper cup",
59
+ image=image,
60
+ mask_image=mask,
61
+ height=1024,
62
+ width=1024,
63
+ max_sequence_length=512,
64
+ generator=torch.Generator("cpu").manual_seed(0),
65
+ ).images[0]
66
+ filename = "output_" + ckpt_id.split("/")[-1].replace(".", "_")
67
+ filename += "_4bit" if four_bit else ""
68
+ image.save(f"{filename}.png")
69
+
70
+
71
+ if __name__ == "__main__":
72
+ fire.Fire(main)
73
+ ```
74
+
75
+ </details>
76
+
77
+ ## Outputs
78
+
79
+ <table>
80
+ <thead>
81
+ <tr>
82
+ <th>Original</th>
83
+ <th>NF4</th>
84
+ </tr>
85
+ </thead>
86
+ <tbody>
87
+ <tr>
88
+ <td>
89
+ <img src="./assets/output_FLUX_1-Fill-dev.png" alt="Original">
90
+ </td>
91
+ <td>
92
+ <img src="./assets/output_FLUX_1-Fill-dev_4bit.png" alt="NF4">
93
+ </td>
94
+ </tr>
95
+ </tbody>
96
+ </table>