---
license: openrail++
library_name: diffusers
tags:
- text-to-image
- diffusers-training
- diffusers
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
base_model: stabilityai/stable-diffusion-xl-base-1.0
---

# Margin-aware Preference Optimization for Aligning Diffusion Models without Reference

<div align="center">
<img src="https://github.com/mapo-t2i/mapo/blob/main/assets/mapo_overview.png?raw=true" width="750"/>
</div><br>

We propose **MaPO**, a reference-free, sample-efficient, memory-friendly alignment technique for text-to-image diffusion models. For more details on the technique, please refer to our paper (link TODO).

## Developed by

* Jiwoo Hong<sup>*</sup> (KAIST AI)
* Sayak Paul<sup>*</sup> (Hugging Face)
* Noah Lee (KAIST AI)
* Kashif Rasul (Hugging Face)
* James Thorne (KAIST AI)
* Jongheon Jeong (Korea University)

## Dataset

This model was fine-tuned from [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset.
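
Each Pick-a-Pic comparison pairs two images generated for the same caption with a human preference label, and preference-alignment methods such as MaPO learn from the resulting (chosen, rejected) pairs. The minimal sketch below shows one way to reorder such a record into that form. It assumes the records expose `caption`, `jpg_0`, `jpg_1`, and `label_0` fields, with `label_0` equal to 1.0, 0.0, or 0.5 (tie); these field names and label semantics are assumptions to check against the dataset itself. The `PreferencePair` type and the `to_preference_pair` helper are hypothetical and are not taken from the MaPO training code.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass
class PreferencePair:
    """One comparison: the preferred ("chosen") image and the other ("rejected") one."""
    caption: str
    chosen: Any
    rejected: Any


def to_preference_pair(record: Mapping[str, Any]) -> Optional[PreferencePair]:
    """Reorder a Pick-a-Pic-style record into (chosen, rejected).

    Assumed label semantics (hypothetical, verify against the dataset):
    label_0 == 1.0 -> jpg_0 preferred, 0.0 -> jpg_1 preferred, 0.5 -> tie.
    This sketch drops ties because they carry no preference signal.
    """
    label_0 = record["label_0"]
    if label_0 == 0.5:
        return None  # tie: no preferred image
    if label_0 == 1.0:
        chosen, rejected = record["jpg_0"], record["jpg_1"]
    elif label_0 == 0.0:
        chosen, rejected = record["jpg_1"], record["jpg_0"]
    else:
        raise ValueError(f"unexpected label_0 value: {label_0!r}")
    return PreferencePair(record["caption"], chosen, rejected)


# Toy in-memory records standing in for dataset rows (image bytes replaced by strings).
records = [
    {"caption": "a red fox", "jpg_0": "img_a", "jpg_1": "img_b", "label_0": 1.0},
    {"caption": "a blue car", "jpg_0": "img_c", "jpg_1": "img_d", "label_0": 0.0},
    {"caption": "a green hat", "jpg_0": "img_e", "jpg_1": "img_f", "label_0": 0.5},
]
pairs = [p for p in (to_preference_pair(r) for r in records) if p is not None]
for p in pairs:
    print(p.caption, p.chosen, p.rejected)
```

Raising on unexpected labels, rather than silently treating them as "jpg_1 preferred", keeps a mislabeled row from flipping a pair without notice.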

## Training Code

The training code is available in our [code repository](https://github.com/mapo-t2i/mapo).
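
## Results

Below, we compare MaPO with the SDXL base model, supervised fine-tuning on the chosen images (SFT<sub>Chosen</sub>), and Diffusion-DPO, using the average Aesthetic, HPS v2.1, and PickScore scores (higher is better for all three):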

<style>
table {
  width: 100%;
  border-collapse: collapse;
}
th, td {
  border: 1px solid #000;
  padding: 8px;
  text-align: center;
}
th {
  background-color: #808080;
}
.ours {
  font-style: italic;
}
</style>

<table>
  <caption>Average scores for Aesthetic, HPS v2.1, and PickScore</caption>
  <thead>
    <tr>
      <th></th>
      <th>Aesthetic</th>
      <th>HPS v2.1</th>
      <th>PickScore</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>SDXL</td>
      <td>6.03</td>
      <td>30.0</td>
      <td>22.4</td>
    </tr>
    <tr>
      <td>SFT<sub>Chosen</sub></td>
      <td>5.95</td>
      <td>29.6</td>
      <td>22.0</td>
    </tr>
    <tr>
      <td>Diffusion-DPO</td>
      <td>6.03</td>
      <td>31.1</td>
      <td>22.7</td>
    </tr>
    <tr class="ours">
      <td>MaPO (Ours)</td>
      <td>6.17</td>
      <td>31.2</td>
      <td>22.5</td>
    </tr>
  </tbody>
</table>
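
To read the table as per-metric gains, the small sketch below copies the reported averages into a dictionary and computes MaPO's difference against each baseline. The dictionary layout and the `deltas` helper are illustrative only; the numbers are exactly those reported in the table.

```python
# Average scores copied from the results table (higher is better for all three metrics).
SCORES = {
    "SDXL": {"Aesthetic": 6.03, "HPS v2.1": 30.0, "PickScore": 22.4},
    "SFT_Chosen": {"Aesthetic": 5.95, "HPS v2.1": 29.6, "PickScore": 22.0},
    "Diffusion-DPO": {"Aesthetic": 6.03, "HPS v2.1": 31.1, "PickScore": 22.7},
    "MaPO": {"Aesthetic": 6.17, "HPS v2.1": 31.2, "PickScore": 22.5},
}


def deltas(model: str, baseline: str) -> dict[str, float]:
    """Per-metric difference (model minus baseline), rounded to 2 decimals to hide float noise."""
    return {
        metric: round(SCORES[model][metric] - SCORES[baseline][metric], 2)
        for metric in SCORES[model]
    }


for baseline in ("SDXL", "SFT_Chosen", "Diffusion-DPO"):
    print(baseline, deltas("MaPO", baseline))
```

The only negative entry is PickScore against Diffusion-DPO (22.5 vs. 22.7); every other difference favors MaPO.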

We evaluated this checkpoint on the imgsys public benchmark. MaPO outperformed or matched 21 of the 25 state-of-the-art text-to-image diffusion models on the leaderboard, ranking 7th at the time of writing compared with Diffusion-DPO's 20th place, while requiring 14.5% less wall-clock training time than Diffusion-DPO when adapting to Pick-a-Pic v2. We thank the imgsys team for helping us obtain the human preference data.

<div align="center">
<img src="https://mapo-t2i.github.io/static/images/imgsys.png" width="750"/>
</div>
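
The 14.5% figure is a relative reduction, so MaPO's wall-clock training time is about 85.5% of Diffusion-DPO's. The sketch below makes that arithmetic concrete; the 100-hour Diffusion-DPO run is a hypothetical input chosen for illustration, not a measured number, and `mapo_hours` is a hypothetical helper.

```python
# Reported relative saving in wall-clock training time versus Diffusion-DPO.
RELATIVE_REDUCTION = 0.145


def mapo_hours(dpo_hours: float, reduction: float = RELATIVE_REDUCTION) -> float:
    """Wall-clock hours implied for MaPO, given a Diffusion-DPO run of `dpo_hours`."""
    return dpo_hours * (1.0 - reduction)


# Hypothetical baseline: a 100-hour Diffusion-DPO run (illustrative only).
hypothetical_dpo_hours = 100.0
print(f"{mapo_hours(hypothetical_dpo_hours):.1f} hours")
```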

## Inference

The snippet below loads the MaPO-aligned UNet into the SDXL base pipeline and generates an image:

```python
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel

sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
# fp16-safe SDXL VAE: the original SDXL VAE can produce NaNs in float16.
vae_id = "madebyollin/sdxl-vae-fp16-fix"
# MaPO-aligned UNet weights.
unet_id = "mapo-t2i/mapo-beta"

vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained(unet_id, torch_dtype=torch.float16)
# Swap the aligned UNet and the fp16-safe VAE into the SDXL base pipeline.
pipeline = DiffusionPipeline.from_pretrained(
    sdxl_id, vae=vae, unet=unet, torch_dtype=torch.float16
).to("cuda")

prompt = "An abstract portrait consisting of bold, flowing brushstrokes against a neutral background."
image = pipeline(prompt=prompt, num_inference_steps=30).images[0]
image.save("mapo_sample.png")
```

For qualitative results, please visit our [project website](https://mapo-t2i.github.io/).

## Citation

```bibtex
@misc{todo,
  title={Margin-aware Preference Optimization for Aligning Diffusion Models without Reference},
  author={Jiwoo Hong and Sayak Paul and Noah Lee and Kashif Rasul and James Thorne and Jongheon Jeong},
  year={2024},
  eprint={todo},
  archivePrefix={arXiv},
  primaryClass={cs.CV,cs.LG}
}
```