Update README.md
Browse files
README.md
CHANGED
@@ -21,21 +21,98 @@ We develope WAFFLE, a fine-tuning approach to train multi-modal LLM (MLLM) to ge
|
|
21 |
- beautifulsoup4 4.12.3
|
22 |
- accelerate 0.30.1
|
23 |
|
24 |
-
## Structure
|
25 |
-
- `vlm_websight` contains the dataset class file, model class files, and training file for vlm_websight.
|
26 |
-
- `eval_websight.py` is the inference file
|
27 |
-
- `dataset.py` is the dataset class file
|
28 |
-
- WebSight-Test is one of our test dataset
|
29 |
-
|
30 |
## Quick Start
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
```
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
## Citation
|
40 |
```
|
41 |
@misc{liang2024wafflemultimodalmodelautomated,
|
|
|
21 |
- beautifulsoup4 4.12.3
|
22 |
- accelerate 0.30.1
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
## Quick Start
|
25 |
+
* Input UI design
|
26 |
+
|
27 |
+
Find a webpage screenshot, or UI design:
|
28 |
+
|
29 |
+
![test-495.png](examples/test-495.png)
|
30 |
+
|
31 |
+
* Run Waffle_VLM_WebSight
|
32 |
+
```python
|
33 |
+
import torch
|
34 |
+
from PIL import Image
|
35 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
36 |
+
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
|
37 |
+
from transformers.image_transforms import resize, to_channel_dimension_format
|
38 |
+
from utils import TreeBuilder
|
39 |
+
|
40 |
+
|
41 |
+
def convert_to_rgb(image):
|
42 |
+
if image.mode == "RGB":
|
43 |
+
return image
|
44 |
+
|
45 |
+
image_rgba = image.convert("RGBA")
|
46 |
+
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
47 |
+
alpha_composite = Image.alpha_composite(background, image_rgba)
|
48 |
+
alpha_composite = alpha_composite.convert("RGB")
|
49 |
+
return alpha_composite
|
50 |
+
|
51 |
+
|
52 |
+
def inference_vlm_websight(image_path, html_path):
|
53 |
+
|
54 |
+
def custom_transform(x):
|
55 |
+
x = convert_to_rgb(x)
|
56 |
+
x = to_numpy_array(x)
|
57 |
+
x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
|
58 |
+
x = processor.image_processor.rescale(x, scale=1 / 255)
|
59 |
+
x = processor.image_processor.normalize(
|
60 |
+
x,
|
61 |
+
mean=processor.image_processor.image_mean,
|
62 |
+
std=processor.image_processor.image_std
|
63 |
+
)
|
64 |
+
x = to_channel_dimension_format(x, ChannelDimension.FIRST)
|
65 |
+
x = torch.tensor(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
model_dir = "lt-asset/Waffle_VLM_WebSight"
|
69 |
+
processor = AutoProcessor.from_pretrained(model_dir)
|
70 |
+
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
|
71 |
+
|
72 |
+
assert model.config.web_attention_range == 2, "Waffle_VLM_WebSight is trained with hierarchical attention applied to 2 / 8 heads"
|
73 |
+
# use 2/8 = 1/4 attention heads for hierarchical attention (as described in paper)
|
74 |
+
model.eval()
|
75 |
+
|
76 |
+
image_seq_len = model.config.perceiver_config.resampler_n_latents
|
77 |
+
BOS_TOKEN = processor.tokenizer.bos_token
|
78 |
+
BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
|
79 |
+
|
80 |
+
image = Image.open(image_path)
|
81 |
+
inputs = processor.tokenizer(
|
82 |
+
f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
|
83 |
+
return_tensors="pt",
|
84 |
+
add_special_tokens=False,
|
85 |
+
)
|
86 |
+
inputs["pixel_values"] = processor.image_processor([image], transform=custom_transform).to(dtype=torch.bfloat16)
|
87 |
+
inputs_for_generation = {k: v.cuda() for k, v in inputs.items()}
|
88 |
+
inputs_for_generation["web_attention_mask"] = None
|
89 |
+
inputs_for_generation["html_tree"] = TreeBuilder(processor.tokenizer)
|
90 |
+
inputs_for_generation["html_tree"].web_attention_mask = inputs_for_generation["web_attention_mask"]
|
91 |
+
|
92 |
+
generated_ids = model.generate(
|
93 |
+
**inputs_for_generation, bad_words_ids=BAD_WORDS_IDS, max_length=2048,
|
94 |
+
num_return_sequences=1, do_sample=False
|
95 |
+
)
|
96 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
97 |
+
|
98 |
+
with open(html_path, 'w') as wp:
|
99 |
+
wp.write(generated_text)
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == '__main__':
|
103 |
+
inference_vlm_websight('examples/test-495.png', 'examples/example-495.html')
|
104 |
```
|
105 |
|
106 |
+
* Waffle_VLM_WebSight generated HTML code
|
107 |
+
|
108 |
+
[example-495.html](examples/example-495.html)
|
109 |
+
|
110 |
+
* Rendered Waffle_VLM_WebSight output
|
111 |
+
|
112 |
+
Render the HTML, or preview the HTML to check the correctness:
|
113 |
+
|
114 |
+
![example-495.html](examples/example-495.png)
|
115 |
+
|
116 |
## Citation
|
117 |
```
|
118 |
@misc{liang2024wafflemultimodalmodelautomated,
|