---
license: bsd-3-clause-clear
---

# WAFFLE: Multi-Modal Model for Automated Front-End Development
We developed WAFFLE, a fine-tuning approach that trains multi-modal LLMs (MLLMs) to generate HTML code from webpage screenshots or UI designs. WAFFLE uses a structure-aware attention mechanism to improve MLLMs' understanding of HTML structure, and a contrastive fine-tuning approach to align MLLMs' understanding of UI images and HTML code. Models fine-tuned with WAFFLE achieve up to 9.00 pp (percentage points) higher HTML match, 0.0982 higher CW-SSIM, 32.99 higher CLIP, and 27.12 pp higher LLEM on our new benchmark WebSight-Test and the existing benchmark Design2Code.
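
To give an intuition for the structure-aware attention idea, consider a mask that keeps attention between tokens whose HTML nodes lie on the same ancestor-descendant path. This is only a toy illustration with a hypothetical helper (`tree_attention_mask`); the released model's actual masking is implemented by `TreeBuilder` in the repo's `utils`, not by this snippet:

```python
# Toy illustration of a tree-based attention mask
# (hypothetical helper, NOT WAFFLE's actual implementation).
import torch

def tree_attention_mask(parent):
    """parent[i] is the index of node i's parent (-1 for the root).
    Allow attention between i and j iff one is an ancestor of the other."""
    n = len(parent)
    mask = torch.eye(n, dtype=torch.bool)  # every node attends to itself
    for i in range(n):
        j = parent[i]
        while j != -1:  # walk up to the root, connecting i to each ancestor
            mask[i, j] = True
            mask[j, i] = True
            j = parent[j]
    return mask

# Toy DOM: 0:<html>  1:<body>  2:<div>  3:<p>  4:<div>
print(tree_attention_mask([-1, 0, 1, 2, 1]).int())
```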

## Updates:
* 10/24/2024: Our preprint is available at: [arXiv](https://arxiv.org/abs/2410.18362), [huggingface](https://huggingface.co/papers/2410.18362)
* 10/24/2024: Our code (actively maintained) is available at: [code](https://github.com/lt-asset/Waffle)
* 10/24/2024: Our Waffle_VLM_WebSight (7B) model, fine-tuned with DoRA, is released at: [lt-asset/Waffle_VLM_WebSight](https://huggingface.co/lt-asset/Waffle_VLM_WebSight)

## Dependency
- Python 3.10.14
- pytorch 2.3.0
- transformers 4.41.1
- peft 0.11.1
- deepspeed 0.14.1
- accelerate 0.30.1
- datasets 2.19.1
- beautifulsoup4 4.12.3
- selenium
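
If you start from a fresh environment, something like `pip install torch==2.3.0 transformers==4.41.1 peft==0.11.1 deepspeed==0.14.1 accelerate==0.30.1 datasets==2.19.1 beautifulsoup4==4.12.3 selenium` should cover the pins above (an assumption on our part; check the repo for an official requirements file).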

## Quick Start
* Input UI design

Pick a webpage screenshot or UI design:

![test-495.png](examples/test-495.png)

* Run Waffle_VLM_WebSight
```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from utils import TreeBuilder


def convert_to_rgb(image):
    # flatten any transparency onto a white background before processing
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def inference_vlm_websight(image_path, html_path):
    
    def custom_transform(x):
        # resize to 960x960, rescale to [0, 1], and normalize with the processor's stats
        x = convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = processor.image_processor.rescale(x, scale=1 / 255)
        x = processor.image_processor.normalize(
            x,
            mean=processor.image_processor.image_mean,
            std=processor.image_processor.image_std
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        x = torch.tensor(x)
        return x

    model_dir = "lt-asset/Waffle_VLM_WebSight"
    processor = AutoProcessor.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
    
    # 2 of the 8 attention heads use hierarchical (structure-aware) attention, as described in the paper
    assert model.config.web_attention_range == 2, \
        "Waffle_VLM_WebSight is trained with hierarchical attention applied to 2 / 8 heads"
    model.eval()

    # number of <image> placeholder tokens expected by the perceiver resampler
    image_seq_len = model.config.perceiver_config.resampler_n_latents
    BOS_TOKEN = processor.tokenizer.bos_token
    # keep the model from emitting image placeholder tokens during generation
    BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

    image = Image.open(image_path)
    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor([image], transform=custom_transform).to(dtype=torch.bfloat16)
    inputs_for_generation = {k: v.cuda() for k, v in inputs.items()}
    # start without a web attention mask; TreeBuilder maintains the HTML tree used for structure-aware attention
    inputs_for_generation["web_attention_mask"] = None
    inputs_for_generation["html_tree"] = TreeBuilder(processor.tokenizer)
    inputs_for_generation["html_tree"].web_attention_mask = inputs_for_generation["web_attention_mask"]

    generated_ids = model.generate(
        **inputs_for_generation, bad_words_ids=BAD_WORDS_IDS, max_length=2048, 
        num_return_sequences=1, do_sample=False
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    with open(html_path, 'w') as wp:
        wp.write(generated_text)


if __name__ == '__main__':
    inference_vlm_websight('examples/test-495.png', 'examples/example-495.html')
```
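
Note that generation here is deterministic (`do_sample=False`), and `max_length=2048` caps the output length, so HTML for very long pages may be truncated; you can raise `max_length` if outputs appear cut off, at the cost of memory and latency.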

* HTML code generated by Waffle_VLM_WebSight

[example-495.html](examples/example-495.html)

* Rendered Waffle_VLM_WebSight output

Render or preview the HTML to check its correctness:

![example-495.html](examples/example-495.png)
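
Since selenium is already in the dependency list, one way to render the generated HTML to a screenshot for side-by-side comparison is the sketch below (assuming a local headless Chrome install; the `render_html` helper, window size, and output path are illustrative):

```python
import os
from selenium import webdriver

def render_html(html_path, png_path, width=1280, height=960):
    # open the local HTML file in headless Chrome and save a screenshot
    options = webdriver.ChromeOptions()
    options.add_argument("--headless=new")
    driver = webdriver.Chrome(options=options)
    try:
        driver.set_window_size(width, height)
        driver.get("file://" + os.path.abspath(html_path))
        driver.save_screenshot(png_path)
    finally:
        driver.quit()

render_html("examples/example-495.html", "examples/rendered-495.png")
```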

## Citation
```bibtex
@misc{liang2024wafflemultimodalmodelautomated,
      title={WAFFLE: Multi-Modal Model for Automated Front-End Development}, 
      author={Shanchao Liang and Nan Jiang and Shangshu Qian and Lin Tan},
      year={2024},
      eprint={2410.18362},
      archivePrefix={arXiv},
      primaryClass={cs.SE},
      url={https://arxiv.org/abs/2410.18362}, 
}
```