azharaslam commited on
Commit
8bd6e88
·
verified ·
1 Parent(s): 4e52d75

Upload 4 files

Browse files
Files changed (4) hide show
  1. README (1).md +12 -0
  2. app.py +264 -264
  3. packages.txt +8 -8
  4. requirements.txt +5 -5
README (1).md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Screenshot to HTML
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.37.2
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Screenshot to HTML/CSS demo.
app.py CHANGED
@@ -1,264 +1,264 @@
1
- import os
2
- import subprocess
3
- import spaces
4
- import torch
5
-
6
- import gradio as gr
7
-
8
- from gradio_client.client import DEFAULT_TEMP_DIR
9
- from playwright.sync_api import sync_playwright
10
- from threading import Thread
11
- from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
12
- from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
13
- from typing import List
14
- from PIL import Image
15
-
16
- from transformers.image_transforms import resize, to_channel_dimension_format
17
-
18
-
19
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
-
21
- DEVICE = torch.device("cuda")
22
- PROCESSOR = AutoProcessor.from_pretrained(
23
- "HuggingFaceM4/VLM_WebSight_finetuned",
24
- )
25
- MODEL = AutoModelForCausalLM.from_pretrained(
26
- "HuggingFaceM4/VLM_WebSight_finetuned",
27
- trust_remote_code=True,
28
- torch_dtype=torch.bfloat16,
29
- ).to(DEVICE)
30
- if MODEL.config.use_resampler:
31
- image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
32
- else:
33
- image_seq_len = (
34
- MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
35
- ) ** 2
36
- BOS_TOKEN = PROCESSOR.tokenizer.bos_token
37
- BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
38
-
39
-
40
- ## Utils
41
-
42
- def convert_to_rgb(image):
43
- # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
44
- # for transparent images. The call to `alpha_composite` handles this case
45
- if image.mode == "RGB":
46
- return image
47
-
48
- image_rgba = image.convert("RGBA")
49
- background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
50
- alpha_composite = Image.alpha_composite(background, image_rgba)
51
- alpha_composite = alpha_composite.convert("RGB")
52
- return alpha_composite
53
-
54
- # The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
55
- # so this is a hack in order to redefine ONLY the transform method
56
- def custom_transform(x):
57
- x = convert_to_rgb(x)
58
- x = to_numpy_array(x)
59
- x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
60
- x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
61
- x = PROCESSOR.image_processor.normalize(
62
- x,
63
- mean=PROCESSOR.image_processor.image_mean,
64
- std=PROCESSOR.image_processor.image_std
65
- )
66
- x = to_channel_dimension_format(x, ChannelDimension.FIRST)
67
- x = torch.tensor(x)
68
- return x
69
-
70
- ## End of Utils
71
-
72
-
73
- IMAGE_GALLERY_PATHS = [
74
- f"example_images/{ex_image}"
75
- for ex_image in os.listdir(f"example_images")
76
- ]
77
-
78
-
79
- def install_playwright():
80
- try:
81
- subprocess.run(["playwright", "install"], check=True)
82
- print("Playwright installation successful.")
83
- except subprocess.CalledProcessError as e:
84
- print(f"Error during Playwright installation: {e}")
85
-
86
- install_playwright()
87
-
88
-
89
- def add_file_gallery(
90
- selected_state: gr.SelectData,
91
- gallery_list: List[str]
92
- ):
93
- return Image.open(gallery_list.root[selected_state.index].image.path)
94
-
95
-
96
- def render_webpage(
97
- html_css_code,
98
- ):
99
- with sync_playwright() as p:
100
- browser = p.chromium.launch(headless=True)
101
- context = browser.new_context(
102
- user_agent=(
103
- "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
104
- " Safari/537.36"
105
- )
106
- )
107
- page = context.new_page()
108
- page.set_content(html_css_code)
109
- page.wait_for_load_state("networkidle")
110
- output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
111
- _ = page.screenshot(path=output_path_screenshot, full_page=True)
112
-
113
- context.close()
114
- browser.close()
115
-
116
- return Image.open(output_path_screenshot)
117
-
118
-
119
- @spaces.GPU(duration=180)
120
- def model_inference(
121
- image,
122
- ):
123
- if image is None:
124
- raise ValueError("`image` is None. It should be a PIL image.")
125
-
126
- inputs = PROCESSOR.tokenizer(
127
- f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
128
- return_tensors="pt",
129
- add_special_tokens=False,
130
- )
131
- inputs["pixel_values"] = PROCESSOR.image_processor(
132
- [image],
133
- transform=custom_transform
134
- )
135
- inputs = {
136
- k: v.to(DEVICE)
137
- for k, v in inputs.items()
138
- }
139
-
140
- streamer = TextIteratorStreamer(
141
- PROCESSOR.tokenizer,
142
- skip_prompt=True,
143
- )
144
- generation_kwargs = dict(
145
- inputs,
146
- bad_words_ids=BAD_WORDS_IDS,
147
- max_length=4096,
148
- streamer=streamer,
149
- )
150
- # Regular generation version
151
- # generation_kwargs.pop("streamer")
152
- # generated_ids = MODEL.generate(**generation_kwargs)
153
- # generated_text = PROCESSOR.batch_decode(
154
- # generated_ids,
155
- # skip_special_tokens=True
156
- # )[0]
157
- # rendered_page = render_webpage(generated_text)
158
- # return generated_text, rendered_page
159
- # Token streaming version
160
- thread = Thread(
161
- target=MODEL.generate,
162
- kwargs=generation_kwargs,
163
- )
164
- thread.start()
165
- generated_text = ""
166
- for new_text in streamer:
167
- if "</s>" in new_text:
168
- new_text = new_text.replace("</s>", "")
169
- rendered_image = render_webpage(generated_text)
170
- else:
171
- rendered_image = None
172
- generated_text += new_text
173
- yield generated_text, rendered_image
174
-
175
-
176
- generated_html = gr.Code(
177
- label="Extracted HTML",
178
- elem_id="generated_html",
179
- )
180
- rendered_html = gr.Image(
181
- label="Rendered HTML",
182
- show_download_button=False,
183
- show_share_button=False,
184
- )
185
- # rendered_html = gr.HTML(
186
- # label="Rendered HTML"
187
- # )
188
-
189
-
190
- css = """
191
- .gradio-container{max-width: 1000px!important}
192
- h1{display: flex;align-items: center;justify-content: center;gap: .25em}
193
- *{transition: width 0.5s ease, flex-grow 0.5s ease}
194
- """
195
-
196
-
197
- with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
198
- gr.Markdown(
199
- "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
200
- )
201
- with gr.Row(equal_height=True):
202
- with gr.Column(scale=4, min_width=250) as upload_area:
203
- imagebox = gr.Image(
204
- type="pil",
205
- label="Screenshot to extract",
206
- visible=True,
207
- sources=["upload", "clipboard"],
208
- )
209
- with gr.Group():
210
- with gr.Row():
211
- submit_btn = gr.Button(
212
- value="▶️ Submit", visible=True, min_width=120
213
- )
214
- clear_btn = gr.ClearButton(
215
- [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
216
- )
217
- regenerate_btn = gr.Button(
218
- value="🔄 Regenerate", visible=True, min_width=120
219
- )
220
- with gr.Column(scale=4):
221
- rendered_html.render()
222
-
223
- with gr.Row():
224
- generated_html.render()
225
-
226
- with gr.Row():
227
- template_gallery = gr.Gallery(
228
- value=IMAGE_GALLERY_PATHS,
229
- label="Templates Gallery",
230
- allow_preview=False,
231
- columns=5,
232
- elem_id="gallery",
233
- show_share_button=False,
234
- height=400,
235
- )
236
-
237
- gr.on(
238
- triggers=[
239
- imagebox.upload,
240
- submit_btn.click,
241
- regenerate_btn.click,
242
- ],
243
- fn=model_inference,
244
- inputs=[imagebox],
245
- outputs=[generated_html, rendered_html],
246
- )
247
- regenerate_btn.click(
248
- fn=model_inference,
249
- inputs=[imagebox],
250
- outputs=[generated_html, rendered_html],
251
- )
252
- template_gallery.select(
253
- fn=add_file_gallery,
254
- inputs=[template_gallery],
255
- outputs=[imagebox],
256
- ).success(
257
- fn=model_inference,
258
- inputs=[imagebox],
259
- outputs=[generated_html, rendered_html],
260
- )
261
- demo.load()
262
-
263
- demo.queue(max_size=40, api_open=False)
264
- demo.launch(max_threads=400)
 
1
+ import os
2
+ import subprocess
3
+ import spaces
4
+ import torch
5
+
6
+ import gradio as gr
7
+
8
+ from gradio_client.client import DEFAULT_TEMP_DIR
9
+ from playwright.sync_api import sync_playwright
10
+ from threading import Thread
11
+ from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
12
+ from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
13
+ from typing import List
14
+ from PIL import Image
15
+
16
+ from transformers.image_transforms import resize, to_channel_dimension_format
17
+
18
+
19
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
+
21
+ DEVICE = torch.device("cuda")
22
+ PROCESSOR = AutoProcessor.from_pretrained(
23
+ "HuggingFaceM4/VLM_WebSight_finetuned",
24
+ )
25
+ MODEL = AutoModelForCausalLM.from_pretrained(
26
+ "HuggingFaceM4/VLM_WebSight_finetuned",
27
+ trust_remote_code=True,
28
+ torch_dtype=torch.bfloat16,
29
+ ).to(DEVICE)
30
+ if MODEL.config.use_resampler:
31
+ image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
32
+ else:
33
+ image_seq_len = (
34
+ MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
35
+ ) ** 2
36
+ BOS_TOKEN = PROCESSOR.tokenizer.bos_token
37
+ BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
38
+
39
+
40
+ ## Utils
41
+
42
+ def convert_to_rgb(image):
43
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
44
+ # for transparent images. The call to `alpha_composite` handles this case
45
+ if image.mode == "RGB":
46
+ return image
47
+
48
+ image_rgba = image.convert("RGBA")
49
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
50
+ alpha_composite = Image.alpha_composite(background, image_rgba)
51
+ alpha_composite = alpha_composite.convert("RGB")
52
+ return alpha_composite
53
+
54
+ # The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
55
+ # so this is a hack in order to redefine ONLY the transform method
56
+ def custom_transform(x):
57
+ x = convert_to_rgb(x)
58
+ x = to_numpy_array(x)
59
+ x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
60
+ x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
61
+ x = PROCESSOR.image_processor.normalize(
62
+ x,
63
+ mean=PROCESSOR.image_processor.image_mean,
64
+ std=PROCESSOR.image_processor.image_std
65
+ )
66
+ x = to_channel_dimension_format(x, ChannelDimension.FIRST)
67
+ x = torch.tensor(x)
68
+ return x
69
+
70
+ ## End of Utils
71
+
72
+
73
+ IMAGE_GALLERY_PATHS = [
74
+ f"example_images/{ex_image}"
75
+ for ex_image in os.listdir(f"example_images")
76
+ ]
77
+
78
+
79
+ def install_playwright():
80
+ try:
81
+ subprocess.run(["playwright", "install"], check=True)
82
+ print("Playwright installation successful.")
83
+ except subprocess.CalledProcessError as e:
84
+ print(f"Error during Playwright installation: {e}")
85
+
86
+ install_playwright()
87
+
88
+
89
+ def add_file_gallery(
90
+ selected_state: gr.SelectData,
91
+ gallery_list: List[str]
92
+ ):
93
+ return Image.open(gallery_list.root[selected_state.index].image.path)
94
+
95
+
96
+ def render_webpage(
97
+ html_css_code,
98
+ ):
99
+ with sync_playwright() as p:
100
+ browser = p.chromium.launch(headless=True)
101
+ context = browser.new_context(
102
+ user_agent=(
103
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
104
+ " Safari/537.36"
105
+ )
106
+ )
107
+ page = context.new_page()
108
+ page.set_content(html_css_code)
109
+ page.wait_for_load_state("networkidle")
110
+ output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
111
+ _ = page.screenshot(path=output_path_screenshot, full_page=True)
112
+
113
+ context.close()
114
+ browser.close()
115
+
116
+ return Image.open(output_path_screenshot)
117
+
118
+
119
+ @spaces.GPU(duration=180)
120
+ def model_inference(
121
+ image,
122
+ ):
123
+ if image is None:
124
+ raise ValueError("`image` is None. It should be a PIL image.")
125
+
126
+ inputs = PROCESSOR.tokenizer(
127
+ f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
128
+ return_tensors="pt",
129
+ add_special_tokens=False,
130
+ )
131
+ inputs["pixel_values"] = PROCESSOR.image_processor(
132
+ [image],
133
+ transform=custom_transform
134
+ )
135
+ inputs = {
136
+ k: v.to(DEVICE)
137
+ for k, v in inputs.items()
138
+ }
139
+
140
+ streamer = TextIteratorStreamer(
141
+ PROCESSOR.tokenizer,
142
+ skip_prompt=True,
143
+ )
144
+ generation_kwargs = dict(
145
+ inputs,
146
+ bad_words_ids=BAD_WORDS_IDS,
147
+ max_length=4096,
148
+ streamer=streamer,
149
+ )
150
+ # Regular generation version
151
+ # generation_kwargs.pop("streamer")
152
+ # generated_ids = MODEL.generate(**generation_kwargs)
153
+ # generated_text = PROCESSOR.batch_decode(
154
+ # generated_ids,
155
+ # skip_special_tokens=True
156
+ # )[0]
157
+ # rendered_page = render_webpage(generated_text)
158
+ # return generated_text, rendered_page
159
+ # Token streaming version
160
+ thread = Thread(
161
+ target=MODEL.generate,
162
+ kwargs=generation_kwargs,
163
+ )
164
+ thread.start()
165
+ generated_text = ""
166
+ for new_text in streamer:
167
+ if "</s>" in new_text:
168
+ new_text = new_text.replace("</s>", "")
169
+ rendered_image = render_webpage(generated_text)
170
+ else:
171
+ rendered_image = None
172
+ generated_text += new_text
173
+ yield generated_text, rendered_image
174
+
175
+
176
+ generated_html = gr.Code(
177
+ label="Extracted HTML",
178
+ elem_id="generated_html",
179
+ )
180
+ rendered_html = gr.Image(
181
+ label="Rendered HTML",
182
+ show_download_button=False,
183
+ show_share_button=False,
184
+ )
185
+ # rendered_html = gr.HTML(
186
+ # label="Rendered HTML"
187
+ # )
188
+
189
+
190
+ css = """
191
+ .gradio-container{max-width: 1000px!important}
192
+ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
193
+ *{transition: width 0.5s ease, flex-grow 0.5s ease}
194
+ """
195
+
196
+
197
+ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
198
+ gr.Markdown(
199
+ "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
200
+ )
201
+ with gr.Row(equal_height=True):
202
+ with gr.Column(scale=4, min_width=250) as upload_area:
203
+ imagebox = gr.Image(
204
+ type="pil",
205
+ label="Screenshot to extract",
206
+ visible=True,
207
+ sources=["upload", "clipboard"],
208
+ )
209
+ with gr.Group():
210
+ with gr.Row():
211
+ submit_btn = gr.Button(
212
+ value="▶️ Submit", visible=True, min_width=120
213
+ )
214
+ clear_btn = gr.ClearButton(
215
+ [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
216
+ )
217
+ regenerate_btn = gr.Button(
218
+ value="🔄 Regenerate", visible=True, min_width=120
219
+ )
220
+ with gr.Column(scale=4):
221
+ rendered_html.render()
222
+
223
+ with gr.Row():
224
+ generated_html.render()
225
+
226
+ with gr.Row():
227
+ template_gallery = gr.Gallery(
228
+ value=IMAGE_GALLERY_PATHS,
229
+ label="Templates Gallery",
230
+ allow_preview=False,
231
+ columns=5,
232
+ elem_id="gallery",
233
+ show_share_button=False,
234
+ height=400,
235
+ )
236
+
237
+ gr.on(
238
+ triggers=[
239
+ imagebox.upload,
240
+ submit_btn.click,
241
+ regenerate_btn.click,
242
+ ],
243
+ fn=model_inference,
244
+ inputs=[imagebox],
245
+ outputs=[generated_html, rendered_html],
246
+ )
247
+ regenerate_btn.click(
248
+ fn=model_inference,
249
+ inputs=[imagebox],
250
+ outputs=[generated_html, rendered_html],
251
+ )
252
+ template_gallery.select(
253
+ fn=add_file_gallery,
254
+ inputs=[template_gallery],
255
+ outputs=[imagebox],
256
+ ).success(
257
+ fn=model_inference,
258
+ inputs=[imagebox],
259
+ outputs=[generated_html, rendered_html],
260
+ )
261
+ demo.load()
262
+
263
+ demo.queue(max_size=40, api_open=False)
264
+ demo.launch(max_threads=400)
packages.txt CHANGED
@@ -1,8 +1,8 @@
1
- libnss3
2
- libnspr4
3
- libatk1.0-0
4
- libatk-bridge2.0-0
5
- libcups2
6
- libatspi2.0-0
7
- libxcomposite1
8
- libxdamage1
 
1
+ libnss3
2
+ libnspr4
3
+ libatk1.0-0
4
+ libatk-bridge2.0-0
5
+ libcups2
6
+ libatspi2.0-0
7
+ libxcomposite1
8
+ libxdamage1
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- playwright
2
- transformers
3
- packaging
4
- ninja
5
- spaces
6
  torch
 
1
+ playwright
2
+ transformers
3
+ packaging
4
+ ninja
5
+ spaces
6
  torch