jiang719 committed
Commit 023f97e
1 Parent(s): f1c7a6e

Update README.md

Files changed (1)
  1. README.md +89 -12
README.md CHANGED
@@ -21,21 +21,98 @@ We develope WAFFLE, a fine-tuning approach to train multi-modal LLM (MLLM) to ge
  - beautifulsoup4 4.12.3
  - accelerate 0.30.1
 
- ## Structure
- - `vlm_websight` contains the dataset class file, model class files, and training file for vlm_websight.
- - `eval_websight.py` is the inference file
- - `dataset.py` is the dataset class file
- - WebSight-Test is one of our test dataset
-
  ## Quick Start
- ```bash
- cd vlm_websight
- # generate HTML/CSS code for UI image --image_path, save the code to --html_path
- python quick_start.py --image_path ../WebSight-Test/test-495.png --html_path examples/example-495.html
- # render the HTML/CSS code in --html_path, and save the rendered image to --image_path
- python render_html.py --html_path examples/example-495.html --image_path examples/example-495.png
+ * Input UI design
+
+ Find a webpage screenshot or UI design:
+
+ ![test-495.png](examples/test-495.png)
+
+ * Run Waffle_VLM_WebSight
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
+ from transformers.image_transforms import resize, to_channel_dimension_format
+ from utils import TreeBuilder
+
+
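+ # Composite transparent (RGBA or paletted) screenshots onto a white background,
+ # so downstream preprocessing always sees 3-channel RGB input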
+ def convert_to_rgb(image):
+     if image.mode == "RGB":
+         return image
+
+     image_rgba = image.convert("RGBA")
+     background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+     alpha_composite = Image.alpha_composite(background, image_rgba)
+     alpha_composite = alpha_composite.convert("RGB")
+     return alpha_composite
+
+
+ def inference_vlm_websight(image_path, html_path):
+
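+     # Image preprocessing: convert to RGB, bilinear-resize to 960x960, rescale to
+     # [0, 1], normalize per channel, and move channels first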
+     def custom_transform(x):
+         x = convert_to_rgb(x)
+         x = to_numpy_array(x)
+         x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
+         x = processor.image_processor.rescale(x, scale=1 / 255)
+         x = processor.image_processor.normalize(
+             x,
+             mean=processor.image_processor.image_mean,
+             std=processor.image_processor.image_std
+         )
+         x = to_channel_dimension_format(x, ChannelDimension.FIRST)
+         x = torch.tensor(x)
+         return x
+
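+     # trust_remote_code=True loads the model's custom vmistral classes bundled
+     # with the checkpoint repo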
+     model_dir = "lt-asset/Waffle_VLM_WebSight"
+     processor = AutoProcessor.from_pretrained(model_dir)
+     model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
+
+     assert model.config.web_attention_range == 2, "Waffle_VLM_WebSight is trained with hierarchical attention applied to 2 / 8 heads"
+     # use 2/8 = 1/4 of the attention heads for hierarchical attention (as described in the paper)
+     model.eval()
+
+     image_seq_len = model.config.perceiver_config.resampler_n_latents
+     BOS_TOKEN = processor.tokenizer.bos_token
+     BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
+
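+     # Load the screenshot and build the text prompt: BOS, then one <image>
+     # placeholder per visual token, wrapped in <fake_token_around_image> markers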
+     image = Image.open(image_path)
+     inputs = processor.tokenizer(
+         f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
+         return_tensors="pt",
+         add_special_tokens=False,
+     )
+     inputs["pixel_values"] = processor.image_processor([image], transform=custom_transform).to(dtype=torch.bfloat16)
+     inputs_for_generation = {k: v.cuda() for k, v in inputs.items()}
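+     # TreeBuilder (from this repo's utils.py) tracks the generated HTML structure
+     # during decoding so the hierarchical (web) attention mask can be built on the fly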
+ inputs_for_generation["web_attention_mask"] = None
89
+ inputs_for_generation["html_tree"] = TreeBuilder(processor.tokenizer)
90
+ inputs_for_generation["html_tree"].web_attention_mask = inputs_for_generation["web_attention_mask"]
91
+
92
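+     # Greedy decoding (do_sample=False); bad_words_ids keeps the literal image
+     # placeholder tokens out of the generated HTML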
+     generated_ids = model.generate(
+         **inputs_for_generation, bad_words_ids=BAD_WORDS_IDS, max_length=2048,
+         num_return_sequences=1, do_sample=False
+     )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     with open(html_path, 'w') as wp:
+         wp.write(generated_text)
+
+
+ if __name__ == '__main__':
+     inference_vlm_websight('examples/test-495.png', 'examples/example-495.html')
  ```
 
+ * Waffle_VLM_WebSight generated HTML code
+
+ [example-495.html](examples/example-495.html)
+
+ * Rendered Waffle_VLM_WebSight output
+
+ Render the HTML, or preview it in a browser, to check correctness (see the rendering sketch below):
+
+ ![example-495.html](examples/example-495.png)
+
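+ A minimal sketch for rendering the generated HTML to a PNG, assuming Playwright
+ (`pip install playwright`, then `playwright install chromium`) in place of the
+ repository's earlier render_html.py script; Playwright is not one of this repo's
+ dependencies:
+
+ ```python
+ # Hypothetical helper, not part of the Waffle repository: render an HTML file
+ # to a PNG screenshot with Playwright's sync API.
+ import pathlib
+ from playwright.sync_api import sync_playwright
+
+ def render_html_to_png(html_path, image_path):
+     url = pathlib.Path(html_path).resolve().as_uri()
+     with sync_playwright() as p:
+         browser = p.chromium.launch()
+         # 960x960 matches the input resolution used by custom_transform above
+         page = browser.new_page(viewport={"width": 960, "height": 960})
+         page.goto(url)
+         page.screenshot(path=image_path, full_page=True)
+         browser.close()
+
+ render_html_to_png('examples/example-495.html', 'examples/example-495.png')
+ ```
+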
  ## Citation
  ```
  @misc{liang2024wafflemultimodalmodelautomated,