---
license: other
license_name: tencent-hunyuan-community
license_link: https://huggingface.co./Tencent-Hunyuan/HunyuanDiT/blob/main/LICENSE.txt
language:
- en
---
# HunyuanDiT LoRA
Language: **English**
## Instructions
The dependencies and installation are essentially the same as those of the [**base model**](https://huggingface.co./Tencent-Hunyuan/HunyuanDiT-v1.2).
We provide two trained LoRA weights for you to test: a jade style and a porcelain style.
Download the LoRA weights with the following commands:
```bash
cd HunyuanDiT
# Use the huggingface-cli tool to download the model.
huggingface-cli download Tencent-Hunyuan/HYDiT-LoRA --local-dir ./ckpts/t2i/lora
# Quick start (the prompt means "Blue-and-white porcelain style, a cat chasing a butterfly")
python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain --infer-mode fa
```
## Training
We provide three types of weights for LoRA fine-tuning: `ema`, `module`, and `distill`; choose whichever gives the best results for your use case. By default, we use the `ema` weights.
The example below trains a LoRA with HunyuanDiT v1.2: it loads the `distill` weights into the main model through the `resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt` setting and then performs LoRA fine-tuning.
If multiple resolutions are used, add the `--multireso` and `--reso-step 64` parameters.
To train a LoRA with HunyuanDiT v1.1, add `--use-style-cond`, `--size-cond 1024 1024`, and `--beta-end 0.03`.
```bash
model='DiT-g/2' # model type
task_flag="lora_porcelain_ema_rank64" # task flag
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=2 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # save a checkpoint every N steps
ckpt_latest_every=2000 # save a checkpoint named `latest.pt` every N steps
rank=64 # LoRA rank
max_training_steps=2000 # maximum number of training steps
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
    --task-flag ${task_flag} \
    --model ${model} \
    --training-parts lora \
    --rank ${rank} \
    --resume \
    --resume-module-root ${resume_module_root} \
    --lr ${lr} \
    --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
    --predict-type v_prediction \
    --uncond-p 0 \
    --uncond-p-t5 0 \
    --index-file ${index_file} \
    --random-flip \
    --batch-size ${batch_size} \
    --image-size ${image_size} \
    --global-seed 999 \
    --grad-accu-steps ${grad_accu_steps} \
    --warmup-num-steps ${warmup_num_steps} \
    --use-flash-attn \
    --use-fp16 \
    --ema-dtype fp32 \
    --results-dir ${results_dir} \
    --ckpt-every ${ckpt_every} \
    --max-training-steps ${max_training_steps} \
    --ckpt-latest-every ${ckpt_latest_every} \
    --log-every 10 \
    --deepspeed \
    --deepspeed-optimizer \
    --use-zero-stage 2 \
    --qk-norm \
    --rope-img base512 \
    --rope-real \
    "$@"
```
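When scripting several runs, the version- and resolution-dependent flags described above can be collected programmatically. The helper below, `lora_version_flags`, is hypothetical (it is not part of the HunyuanDiT repository) and only encodes the flags this section names. Because the training script forwards extra arguments through `"$@"`, the returned flags can be appended to its command line; this sketch assumes the script parses options with Python's `argparse`, where the last occurrence of a repeated option such as `--beta-end` takes effect.

```python
from typing import List


def lora_version_flags(version: str, multireso: bool = False) -> List[str]:
    """Extra training flags for a HunyuanDiT version (hypothetical helper).

    Only encodes the flags named in this README; version "1.2" needs none
    beyond the example script above.
    """
    flags: List[str] = []
    if multireso:
        # Multi-resolution training uses both flags together.
        flags += ["--multireso", "--reso-step", "64"]
    if version == "1.1":
        # v1.1 adds style/size conditioning and a larger beta-end
        # (0.03 instead of the 0.018 set in the v1.2 script).
        flags += ["--use-style-cond", "--size-cond", "1024", "1024",
                  "--beta-end", "0.03"]
    elif version != "1.2":
        raise ValueError(f"unsupported HunyuanDiT version: {version}")
    return flags


if __name__ == "__main__":
    # Flags to append after the training script for a multi-resolution v1.1 run.
    print(" ".join(lora_version_flags("1.1", multireso=True)))
```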
Recommended parameter settings:

| Parameter | Description | Recommended value | Note |
|:---------------:|:---------:|:---------------------------------------------------:|:--:|
| `--batch-size` | Training batch size | 1 | Depends on GPU memory |
| `--grad-accu-steps` | Number of gradient accumulation steps | 2 | - |
| `--rank` | LoRA rank | 64 | Choose a value from 8 to 128 |
| `--max-training-steps` | Number of training steps | 2000 | Depends on the training data size; for reference, 2000 steps for 100 images |
| `--lr` | Learning rate | 0.0001 | - |
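To relate the reference point in the table (2000 steps for 100 images) to your own dataset, standard data-parallel arithmetic helps: each optimizer update sees `batch_size × grad_accu_steps × num_gpus` samples, the same relation DeepSpeed uses for its `train_batch_size`. The helper names below are hypothetical, and the epoch estimate assumes that one training step equals one optimizer update; check how the training script counts steps before relying on it.

```python
def effective_batch_size(batch_size: int, grad_accu_steps: int, num_gpus: int = 1) -> int:
    """Samples contributing to one optimizer update across all GPUs."""
    return batch_size * grad_accu_steps * num_gpus


def approx_epochs(max_training_steps: int, num_images: int, batch_size: int,
                  grad_accu_steps: int, num_gpus: int = 1) -> float:
    """Rough number of passes over the dataset.

    Assumption: one "training step" is one optimizer update. If the script
    counts micro-batches instead, divide the result by grad_accu_steps.
    """
    samples_seen = max_training_steps * effective_batch_size(
        batch_size, grad_accu_steps, num_gpus)
    return samples_seen / num_images


if __name__ == "__main__":
    # Recommended values from the table, on a single GPU.
    print(effective_batch_size(1, 2), approx_epochs(2000, 100, 1, 2))
```

With the recommended values on one GPU, this gives an effective batch size of 2 and about 40 passes over the 100 reference images, under the stated assumption.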
## Inference
### Using Gradio
Make sure you have activated the conda environment before running the following commands.
> ⚠️ Important Reminder:
> We recommend disabling prompt enhancement (`--no-enhance`), as it may drop the style words from the prompt.
```shell
# jade style
# Using Flash Attention for acceleration.
python app/hydit_app.py --infer-mode fa --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# You can disable the enhancement model if the GPU memory is insufficient.
# The enhancement will be unavailable until you restart the app without the `--no-enhance` flag.
python app/hydit_app.py --infer-mode fa --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# Start with English UI
python app/hydit_app.py --infer-mode fa --lang en --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# porcelain style
# Using Flash Attention for acceleration.
python app/hydit_app.py --infer-mode fa --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
# You can disable the enhancement model if the GPU memory is insufficient.
# The enhancement will be unavailable until you restart the app without the `--no-enhance` flag.
python app/hydit_app.py --infer-mode fa --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
# Start with English UI
python app/hydit_app.py --infer-mode fa --lang en --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
```
### Using Command Line
We provide several commands for a quick start:
```shell
# jade style
# The prompt means "Jade painting style, a cat chasing a butterfly".
# Prompt enhancement + text-to-image (Flash Attention mode)
python sample_t2i.py --infer-mode fa --prompt "玉石绘画风格,一只猫在追蝴蝶" --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# Text-to-image only (Flash Attention mode)
python sample_t2i.py --infer-mode fa --prompt "玉石绘画风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# Generate an image with a different image size.
python sample_t2i.py --infer-mode fa --prompt "玉石绘画风格,一只猫在追蝴蝶" --image-size 1280 768 --load-key ema --lora-ckpt ./ckpts/t2i/lora/jade
# porcelain style
# The prompt means "Blue-and-white porcelain style, a cat chasing a butterfly".
# Prompt enhancement + text-to-image (Flash Attention mode)
python sample_t2i.py --infer-mode fa --prompt "青花瓷风格,一只猫在追蝴蝶" --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
# Text-to-image only (Flash Attention mode)
python sample_t2i.py --infer-mode fa --prompt "青花瓷风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
# Generate an image with a different image size.
python sample_t2i.py --infer-mode fa --prompt "青花瓷风格,一只猫在追蝴蝶" --image-size 1280 768 --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain
```
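For batch generation, the commands above can be assembled as argument lists and passed to `subprocess.run` from the HunyuanDiT repository root. `build_sample_cmd` is a hypothetical helper that reuses only the flags shown in this section; it defaults to `--no-enhance`, following the recommendation to avoid prompt enhancement with LoRA styles, and passes the two `--image-size` values through in the order given.

```python
import shlex
from typing import List, Optional, Tuple


def build_sample_cmd(prompt: str, style: str, enhance: bool = False,
                     image_size: Optional[Tuple[int, int]] = None) -> List[str]:
    """Build a sample_t2i.py command for a LoRA style such as "jade" or "porcelain"."""
    cmd = ["python", "sample_t2i.py", "--infer-mode", "fa", "--prompt", prompt]
    if not enhance:
        # Prompt enhancement may drop the style words, so it is off by default.
        cmd.append("--no-enhance")
    if image_size is not None:
        # Passed through verbatim, e.g. (1280, 768) -> --image-size 1280 768
        cmd += ["--image-size", str(image_size[0]), str(image_size[1])]
    cmd += ["--load-key", "ema", "--lora-ckpt", f"./ckpts/t2i/lora/{style}"]
    return cmd


if __name__ == "__main__":
    cmd = build_sample_cmd("青花瓷风格,一只猫在追蝴蝶", "porcelain", image_size=(1280, 768))
    # Print a shell-quoted version; to run it, use subprocess.run(cmd, check=True).
    print(shlex.join(cmd))
```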
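To use our trained LoRA weights with diffusers, run the following script. The diffusers implementation of HunyuanDiT names and splits some layers differently (for example, the fused `Wqkv` and `kv_proj` projections become separate `to_q`, `to_k`, and `to_v` weights), so the LoRA weights cannot be loaded directly; instead, the script merges them into the transformer weights.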
```python
import torch
from diffusers import HunyuanDiTPipeline
from safetensors import safe_open

# Number of transformer blocks in HunyuanDiT.
num_layers = 40


def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
    """Merge each LoRA delta (lora_B @ lora_A) into the diffusers transformer weights."""
    for i in range(num_layers):
        # Self-attention: the fused Wqkv delta is split into to_q / to_k / to_v.
        Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])
        q, k, v = torch.chunk(Wqkv, 3, dim=0)
        transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q
        transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
        transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v
        out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"])
        transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj

        # Cross-attention: q_proj maps to to_q; the fused kv_proj is split into to_k / to_v.
        q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
        transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj
        kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
        k, v = torch.chunk(kv_proj, 2, dim=0)
        transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
        transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v
        out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"])
        transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj

    # Attention pooler, which lives under time_extra_emb in diffusers.
    q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
    transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
    return transformer_state_dict


pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", torch_dtype=torch.float16)
pipe.to("cuda")

lora_state_dict = {}
with safe_open("./ckpts/t2i/lora/jade/adapter_model.safetensors", framework="pt", device=0) as f:
    for key in f.keys():
        # Strip the 17-character "base_model.model." prefix from each key.
        lora_state_dict[key[17:]] = f.get_tensor(key)

transformer_state_dict = pipe.transformer.state_dict()
transformer_state_dict = load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale=1.0)
pipe.transformer.load_state_dict(transformer_state_dict)

prompt = "玉石绘画风格,一只猫在追蝴蝶"  # "Jade painting style, a cat chasing a butterfly"
image = pipe(
    prompt,
    num_inference_steps=100,
    guidance_scale=6.0,
).images[0]
image.save("img.png")
```
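The conversion script relies on two properties of LoRA's linear update: splitting the rows of a fused delta `lora_B @ lora_A` gives the same result as building each projection's delta from its own rows of `lora_B`, and a merged weight `W + s * B A` produces the same output as applying the adapter at runtime, `W x + s * B (A x)`. The toy check below uses tiny made-up matrices (hidden size 2, rank 1) in plain Python rather than real HunyuanDiT weights, with `chunk_rows` standing in for `torch.chunk(..., dim=0)`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add_scaled(w: Matrix, delta: Matrix, scale: float) -> Matrix:
    """Return w + scale * delta, the merge step used in the script above."""
    return [[wij + scale * dij for wij, dij in zip(wr, dr)] for wr, dr in zip(w, delta)]


def chunk_rows(m: Matrix, n: int) -> List[Matrix]:
    """Split the rows of m into n equal blocks, like torch.chunk(m, n, dim=0)."""
    size = len(m) // n
    return [m[i * size:(i + 1) * size] for i in range(n)]


# Toy fused projection: hidden size 2 and rank 1, so the fused Wqkv has 3 * 2 = 6 rows.
lora_A = [[1.0, 2.0]]                                # shape (rank, in)
lora_B = [[1.0], [0.0], [2.0], [1.0], [0.0], [3.0]]  # shape (3 * out, rank)
W_q = [[1.0, 0.0], [0.0, 1.0]]                       # original to_q weight
scale = 0.5

# Split the fused delta into q/k/v parts, as the script does for Wqkv.
delta_q, delta_k, delta_v = chunk_rows(matmul(lora_B, lora_A), 3)
B_q = chunk_rows(lora_B, 3)[0]
assert delta_q == matmul(B_q, lora_A)

# Merged weight versus adapter applied at runtime on the same input x.
x = [[3.0], [4.0]]
merged_out = matmul(add_scaled(W_q, delta_q, scale), x)
runtime_out = add_scaled(matmul(W_q, x), matmul(B_q, matmul(lora_A, x)), scale)
print(merged_out, runtime_out)  # prints [[8.5], [4.0]] [[8.5], [4.0]]
```

Because the merge edits the weights in place, undoing it means subtracting the same scaled deltas or reloading the original pipeline.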
More example prompts can be found in [example_prompts.txt](example_prompts.txt)