abdullahmeda
commited on
Commit
·
a4d40bc
1
Parent(s):
b9309ba
- README.md +31 -6
- app.py +109 -0
- examples/tt0068646-the-godfather.jpg +0 -0
- examples/tt0076759-star-wars.jpg +0 -0
- examples/tt0108778-friends.jpg +0 -0
- examples/tt0109830-forrest-gump.jpg +0 -0
- examples/tt0434409-v-for-vendetta.jpg +0 -0
- examples/tt10062292-never-have-i-ever.jpg +0 -0
- examples/tt10919420-squid-games.jpg +0 -0
- examples/tt3521164-moana.jpg +0 -0
- examples/tt6468322-money-heist.jpg +0 -0
- examples/tt7991608-red-notice.jpg +0 -0
- examples/tt8366590-baaghi3.jpg +0 -0
- flagged/image/0.jpg +0 -0
- flagged/image/1.jpg +0 -0
- flagged/log.csv +3 -0
- requirements.txt +4 -0
- test.py +94 -0
- train/README.md +9 -0
- train/create_dataset.ipynb +326 -0
- train/requirements.txt +7 -0
- train/train.ipynb +474 -0
README.md
CHANGED
@@ -1,12 +1,37 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.0.20
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Poster2plot
|
3 |
+
emoji: 🎬
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
|
|
7 |
app_file: app.py
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
+
# Configuration
|
12 |
+
|
13 |
+
`title`: _string_
|
14 |
+
Display title for the Space
|
15 |
+
|
16 |
+
`emoji`: _string_
|
17 |
+
Space emoji (emoji-only character allowed)
|
18 |
+
|
19 |
+
`colorFrom`: _string_
|
20 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
21 |
+
|
22 |
+
`colorTo`: _string_
|
23 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
24 |
+
|
25 |
+
`sdk`: _string_
|
26 |
+
Can be either `gradio` or `streamlit`
|
27 |
+
|
28 |
+
`sdk_version` : _string_
|
29 |
+
Only applicable for `streamlit` SDK.
|
30 |
+
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
31 |
+
|
32 |
+
`app_file`: _string_
|
33 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
34 |
+
Path is relative to the root of the repository.
|
35 |
+
|
36 |
+
`pinned`: _boolean_
|
37 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import re
|
4 |
+
import gradio as gr
|
5 |
+
from pathlib import Path
|
6 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
7 |
+
|
8 |
+
|
9 |
+
# Pattern to ignore all the text after 2 or more full stops
|
10 |
+
regex_pattern = "[.]{2,}"
|
11 |
+
|
12 |
+
|
13 |
+
def post_process(text):
|
14 |
+
try:
|
15 |
+
text = text.strip()
|
16 |
+
text = re.split(regex_pattern, text)[0]
|
17 |
+
except Exception as e:
|
18 |
+
print(e)
|
19 |
+
pass
|
20 |
+
return text
|
21 |
+
|
22 |
+
|
23 |
+
def set_example_image(example: list) -> dict:
|
24 |
+
return gr.Image.update(value=example[0])
|
25 |
+
|
26 |
+
|
27 |
+
def predict(image, max_length=64, num_beams=4):
|
28 |
+
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
29 |
+
pixel_values = pixel_values.to(device)
|
30 |
+
|
31 |
+
with torch.no_grad():
|
32 |
+
output_ids = model.generate(
|
33 |
+
pixel_values,
|
34 |
+
max_length=max_length,
|
35 |
+
num_beams=num_beams,
|
36 |
+
return_dict_in_generate=True,
|
37 |
+
).sequences
|
38 |
+
|
39 |
+
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
40 |
+
pred = post_process(preds[0])
|
41 |
+
|
42 |
+
return pred
|
43 |
+
|
44 |
+
|
45 |
+
model_name_or_path = "deepklarity/poster2plot"
|
46 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
47 |
+
|
48 |
+
# Load model.
|
49 |
+
|
50 |
+
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
|
51 |
+
model.to(device)
|
52 |
+
print("Loaded model")
|
53 |
+
|
54 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
|
55 |
+
print("Loaded feature_extractor")
|
56 |
+
|
57 |
+
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
58 |
+
if model.decoder.name_or_path == "gpt2":
|
59 |
+
tokenizer.pad_token = tokenizer.eos_token
|
60 |
+
|
61 |
+
print("Loaded tokenizer")
|
62 |
+
|
63 |
+
title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
|
64 |
+
description = ""
|
65 |
+
|
66 |
+
input = gr.inputs.Image(type="pil")
|
67 |
+
|
68 |
+
example_images = sorted(
|
69 |
+
[f.as_posix() for f in Path("examples").glob("*.jpg")]
|
70 |
+
)
|
71 |
+
print(f"Loaded {len(example_images)} example images")
|
72 |
+
|
73 |
+
demo = gr.Blocks()
|
74 |
+
filenames = next(os.walk('examples'), (None, None, []))[2]
|
75 |
+
examples = [[f"examples/{filename}"] for filename in filenames]
|
76 |
+
print(examples)
|
77 |
+
|
78 |
+
with demo:
|
79 |
+
with gr.Column():
|
80 |
+
with gr.Row():
|
81 |
+
with gr.Column():
|
82 |
+
input_image = gr.Image()
|
83 |
+
with gr.Row():
|
84 |
+
clear_button = gr.Button(value="Clear", variant='secondary')
|
85 |
+
submit_button = gr.Button(value="Submit", variant='primary')
|
86 |
+
with gr.Column():
|
87 |
+
plot = gr.Textbox()
|
88 |
+
with gr.Row():
|
89 |
+
example_images = gr.Dataset(components=[input_image], samples=examples)
|
90 |
+
|
91 |
+
submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
|
92 |
+
example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
|
93 |
+
|
94 |
+
demo.launch()
|
95 |
+
|
96 |
+
|
97 |
+
interface = gr.Interface(
|
98 |
+
fn=predict,
|
99 |
+
inputs=input,
|
100 |
+
outputs="textbox",
|
101 |
+
title=title,
|
102 |
+
description=description,
|
103 |
+
examples=example_images,
|
104 |
+
examples_per_page=20,
|
105 |
+
live=True,
|
106 |
+
article='<p>Made by: <a href="https://twitter.com/kartik_godawat" target="_blank" rel="noopener noreferrer">dk-crazydiv</a> and <a href="https://twitter.com/dsr_ai" target="_blank" rel="noopener noreferrer">dsr</a></p>'
|
107 |
+
)
|
108 |
+
|
109 |
+
interface.launch()
|
examples/tt0068646-the-godfather.jpg
ADDED
examples/tt0076759-star-wars.jpg
ADDED
examples/tt0108778-friends.jpg
ADDED
examples/tt0109830-forrest-gump.jpg
ADDED
examples/tt0434409-v-for-vendetta.jpg
ADDED
examples/tt10062292-never-have-i-ever.jpg
ADDED
examples/tt10919420-squid-games.jpg
ADDED
examples/tt3521164-moana.jpg
ADDED
examples/tt6468322-money-heist.jpg
ADDED
examples/tt7991608-red-notice.jpg
ADDED
examples/tt8366590-baaghi3.jpg
ADDED
flagged/image/0.jpg
ADDED
flagged/image/1.jpg
ADDED
flagged/log.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
'image','output','flag','username','timestamp'
|
2 |
+
'image/0.jpg','A young woman is forced to deal with her past when she is accused of murder. She tries to find out what happened to her husband, who is also accused of the crime. Will she be able to solve the case or will she be the one to save her husband''s life? Based on the true story of','','','2022-06-23 18:30:55.658016'
|
3 |
+
'image/1.jpg','A young woman is forced to deal with her past when she is accused of murder. She tries to find out what happened to her husband, who is also accused of the crime. Will she be able to solve the case or will she be the one to save her husband''s life? Based on the true story of','','','2022-06-23 18:30:57.352462'
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
gradio==2.9.0
|
3 |
+
transformers==4.12.5
|
4 |
+
torch==1.10.0+cpu
|
test.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import re
|
4 |
+
import gradio as gr
|
5 |
+
from pathlib import Path
|
6 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
7 |
+
|
8 |
+
|
9 |
+
# Pattern to ignore all the text after 2 or more full stops
|
10 |
+
regex_pattern = "[.]{2,}"
|
11 |
+
|
12 |
+
|
13 |
+
def post_process(text):
|
14 |
+
try:
|
15 |
+
text = text.strip()
|
16 |
+
text = re.split(regex_pattern, text)[0]
|
17 |
+
except Exception as e:
|
18 |
+
print(e)
|
19 |
+
pass
|
20 |
+
return text
|
21 |
+
|
22 |
+
|
23 |
+
def set_example_image(example: list) -> dict:
|
24 |
+
return gr.Image.update(value=example[0])
|
25 |
+
|
26 |
+
|
27 |
+
def predict(image, max_length=64, num_beams=4):
|
28 |
+
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
29 |
+
pixel_values = pixel_values.to(device)
|
30 |
+
|
31 |
+
with torch.no_grad():
|
32 |
+
output_ids = model.generate(
|
33 |
+
pixel_values,
|
34 |
+
max_length=max_length,
|
35 |
+
num_beams=num_beams,
|
36 |
+
return_dict_in_generate=True,
|
37 |
+
).sequences
|
38 |
+
|
39 |
+
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
40 |
+
pred = post_process(preds[0])
|
41 |
+
|
42 |
+
return pred
|
43 |
+
|
44 |
+
|
45 |
+
model_name_or_path = "deepklarity/poster2plot"
|
46 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
47 |
+
|
48 |
+
# Load model.
|
49 |
+
|
50 |
+
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
|
51 |
+
model.to(device)
|
52 |
+
print("Loaded model")
|
53 |
+
|
54 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
|
55 |
+
print("Loaded feature_extractor")
|
56 |
+
|
57 |
+
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
58 |
+
if model.decoder.name_or_path == "gpt2":
|
59 |
+
tokenizer.pad_token = tokenizer.eos_token
|
60 |
+
|
61 |
+
print("Loaded tokenizer")
|
62 |
+
|
63 |
+
title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
|
64 |
+
description = ""
|
65 |
+
|
66 |
+
input = gr.inputs.Image(type="pil")
|
67 |
+
|
68 |
+
example_images = sorted(
|
69 |
+
[f.as_posix() for f in Path("examples").glob("*.jpg")]
|
70 |
+
)
|
71 |
+
print(f"Loaded {len(example_images)} example images")
|
72 |
+
|
73 |
+
demo = gr.Blocks()
|
74 |
+
filenames = next(os.walk('examples'), (None, None, []))[2]
|
75 |
+
examples = [[f"examples/{filename}"] for filename in filenames]
|
76 |
+
print(examples)
|
77 |
+
|
78 |
+
with demo:
|
79 |
+
with gr.Column():
|
80 |
+
with gr.Row():
|
81 |
+
with gr.Column():
|
82 |
+
input_image = gr.Image()
|
83 |
+
with gr.Row():
|
84 |
+
clear_button = gr.Button(value="Clear", variant='secondary')
|
85 |
+
submit_button = gr.Button(value="Submit", variant='primary')
|
86 |
+
with gr.Column():
|
87 |
+
plot = gr.Textbox()
|
88 |
+
with gr.Row():
|
89 |
+
example_images = gr.Dataset(components=[input_image], samples=examples)
|
90 |
+
|
91 |
+
submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
|
92 |
+
example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
|
93 |
+
|
94 |
+
demo.launch()
|
train/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Train new model
|
2 |
+
|
3 |
+
- Download and extract the following datasets in a new folder called datasets:
|
4 |
+
|
5 |
+
1. [IMDb movies extensive dataset](https://www.kaggle.com/stefanoleone992/imdb-extensive-dataset)
|
6 |
+
2. [48K IMDB Movies With Posters](https://www.kaggle.com/rezaunderfit/48k-imdb-movies-with-posters)
|
7 |
+
|
8 |
+
- Run `create_dataset.ipynb` to create train.csv and valid.csv
|
9 |
+
- Run `train.ipynb` to train the model
|
train/create_dataset.ipynb
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "0fbed7bc",
|
7 |
+
"metadata": {
|
8 |
+
"ExecuteTime": {
|
9 |
+
"end_time": "2021-12-09T16:46:29.851016Z",
|
10 |
+
"start_time": "2021-12-09T16:46:29.841794Z"
|
11 |
+
},
|
12 |
+
"pycharm": {
|
13 |
+
"name": "#%%\n"
|
14 |
+
}
|
15 |
+
},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"%reload_ext autoreload\n",
|
19 |
+
"%autoreload 2"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": null,
|
25 |
+
"id": "99d6f14d",
|
26 |
+
"metadata": {
|
27 |
+
"ExecuteTime": {
|
28 |
+
"end_time": "2021-12-09T16:46:30.336104Z",
|
29 |
+
"start_time": "2021-12-09T16:46:29.852308Z"
|
30 |
+
},
|
31 |
+
"pycharm": {
|
32 |
+
"name": "#%%\n"
|
33 |
+
}
|
34 |
+
},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"from pathlib import Path\n",
|
38 |
+
"import pandas as pd\n",
|
39 |
+
"import shutil\n",
|
40 |
+
"from sklearn.model_selection import train_test_split"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"id": "c8fcf96c",
|
47 |
+
"metadata": {
|
48 |
+
"ExecuteTime": {
|
49 |
+
"end_time": "2021-12-09T16:46:30.349125Z",
|
50 |
+
"start_time": "2021-12-09T16:46:30.337223Z"
|
51 |
+
},
|
52 |
+
"code_folding": [],
|
53 |
+
"pycharm": {
|
54 |
+
"name": "#%%\n"
|
55 |
+
}
|
56 |
+
},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"def copy_images(\n",
|
60 |
+
" src_dir: Path,\n",
|
61 |
+
" des_dir: Path,\n",
|
62 |
+
" ids_with_plots: list,\n",
|
63 |
+
" delete_existing_files: bool = False,\n",
|
64 |
+
"):\n",
|
65 |
+
" \"\"\"This function copies a poster to images folder if it's id is present in the ids_with_plots list\"\"\"\n",
|
66 |
+
"\n",
|
67 |
+
" images_list = []\n",
|
68 |
+
" if delete_existing_files:\n",
|
69 |
+
" shutil.rmtree(des_dir)\n",
|
70 |
+
"\n",
|
71 |
+
" des_dir.mkdir(parents=True, exist_ok=True)\n",
|
72 |
+
"\n",
|
73 |
+
" for f in src_dir.rglob(\"*\"):\n",
|
74 |
+
" try:\n",
|
75 |
+
" if f.is_file() and f.suffix in [\".jpg\", \".jpeg\", \".png\"]:\n",
|
76 |
+
" img_name = f.name\n",
|
77 |
+
" id = Path(img_name).stem\n",
|
78 |
+
" if id in ids_with_plots:\n",
|
79 |
+
" desc_file = des_dir / img_name\n",
|
80 |
+
" shutil.copy(f, desc_file)\n",
|
81 |
+
" images_list.append((id, img_name))\n",
|
82 |
+
" except Exception as e:\n",
|
83 |
+
" print(f, e)\n",
|
84 |
+
" return images_list"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": null,
|
90 |
+
"id": "a34124b2",
|
91 |
+
"metadata": {
|
92 |
+
"ExecuteTime": {
|
93 |
+
"end_time": "2021-12-09T16:46:30.359361Z",
|
94 |
+
"start_time": "2021-12-09T16:46:30.350299Z"
|
95 |
+
},
|
96 |
+
"pycharm": {
|
97 |
+
"name": "#%%\n"
|
98 |
+
}
|
99 |
+
},
|
100 |
+
"outputs": [],
|
101 |
+
"source": [
|
102 |
+
"data_dir = Path(\"datasets\").resolve()\n",
|
103 |
+
"images_dir = data_dir / \"images\""
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"id": "8714ea01",
|
110 |
+
"metadata": {
|
111 |
+
"ExecuteTime": {
|
112 |
+
"end_time": "2021-12-09T16:46:30.781046Z",
|
113 |
+
"start_time": "2021-12-09T16:46:30.360608Z"
|
114 |
+
},
|
115 |
+
"pycharm": {
|
116 |
+
"name": "#%%\n"
|
117 |
+
}
|
118 |
+
},
|
119 |
+
"outputs": [],
|
120 |
+
"source": [
|
121 |
+
"movies_df = pd.read_csv(\n",
|
122 |
+
" data_dir / \"IMDb movies.csv\", usecols=[\"imdb_title_id\", \"description\"]\n",
|
123 |
+
")\n",
|
124 |
+
"movies_df = movies_df.rename(columns={\"imdb_title_id\": \"id\", \"description\": \"text\"})\n",
|
125 |
+
"movies_df.dropna(subset=[\"text\"], inplace=True) # Drop rows where text is empty\n",
|
126 |
+
"movies_df.head()\n"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": null,
|
132 |
+
"id": "27f7fd94",
|
133 |
+
"metadata": {
|
134 |
+
"ExecuteTime": {
|
135 |
+
"end_time": "2021-12-09T16:46:30.792761Z",
|
136 |
+
"start_time": "2021-12-09T16:46:30.781964Z"
|
137 |
+
},
|
138 |
+
"pycharm": {
|
139 |
+
"name": "#%%\n"
|
140 |
+
}
|
141 |
+
},
|
142 |
+
"outputs": [],
|
143 |
+
"source": [
|
144 |
+
"ids_with_plots = movies_df.id.tolist()"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": null,
|
150 |
+
"id": "ebaa042a",
|
151 |
+
"metadata": {
|
152 |
+
"ExecuteTime": {
|
153 |
+
"end_time": "2021-12-09T16:47:04.704390Z",
|
154 |
+
"start_time": "2021-12-09T16:46:30.794094Z"
|
155 |
+
},
|
156 |
+
"pycharm": {
|
157 |
+
"name": "#%%\n"
|
158 |
+
}
|
159 |
+
},
|
160 |
+
"outputs": [],
|
161 |
+
"source": [
|
162 |
+
"images_list = copy_images(data_dir / \"Poster\", images_dir, ids_with_plots)\n",
|
163 |
+
"images_list[0]"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": null,
|
169 |
+
"id": "17e0a874",
|
170 |
+
"metadata": {
|
171 |
+
"ExecuteTime": {
|
172 |
+
"end_time": "2021-12-09T16:47:04.724427Z",
|
173 |
+
"start_time": "2021-12-09T16:47:04.705540Z"
|
174 |
+
},
|
175 |
+
"pycharm": {
|
176 |
+
"name": "#%%\n"
|
177 |
+
}
|
178 |
+
},
|
179 |
+
"outputs": [],
|
180 |
+
"source": [
|
181 |
+
"images_df = pd.DataFrame(images_list, columns=[\"id\", \"filename\"])\n",
|
182 |
+
"images_df.head()"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"execution_count": null,
|
188 |
+
"id": "bb1114e6",
|
189 |
+
"metadata": {
|
190 |
+
"ExecuteTime": {
|
191 |
+
"end_time": "2021-12-09T16:47:04.772775Z",
|
192 |
+
"start_time": "2021-12-09T16:47:04.725707Z"
|
193 |
+
},
|
194 |
+
"pycharm": {
|
195 |
+
"name": "#%%\n"
|
196 |
+
}
|
197 |
+
},
|
198 |
+
"outputs": [],
|
199 |
+
"source": [
|
200 |
+
"data_df = pd.merge(movies_df, images_df, on=[\"id\"])\n",
|
201 |
+
"print(len(data_df))\n",
|
202 |
+
"data_df"
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"id": "6790815b",
|
209 |
+
"metadata": {
|
210 |
+
"ExecuteTime": {
|
211 |
+
"end_time": "2021-12-09T16:47:04.796785Z",
|
212 |
+
"start_time": "2021-12-09T16:47:04.774932Z"
|
213 |
+
},
|
214 |
+
"pycharm": {
|
215 |
+
"name": "#%%\n"
|
216 |
+
}
|
217 |
+
},
|
218 |
+
"outputs": [],
|
219 |
+
"source": [
|
220 |
+
"print(len(data_df))\n",
|
221 |
+
"data_df.dropna(subset=[\"filename\"], inplace=True)\n",
|
222 |
+
"print(len(data_df))"
|
223 |
+
]
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"cell_type": "code",
|
227 |
+
"execution_count": null,
|
228 |
+
"id": "40c7205d",
|
229 |
+
"metadata": {
|
230 |
+
"ExecuteTime": {
|
231 |
+
"end_time": "2021-12-09T16:47:04.818522Z",
|
232 |
+
"start_time": "2021-12-09T16:47:04.798063Z"
|
233 |
+
},
|
234 |
+
"pycharm": {
|
235 |
+
"name": "#%%\n"
|
236 |
+
}
|
237 |
+
},
|
238 |
+
"outputs": [],
|
239 |
+
"source": [
|
240 |
+
"print(len(data_df))\n",
|
241 |
+
"data_df.dropna(subset=[\"text\"], inplace=True)\n",
|
242 |
+
"print(len(data_df))"
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "code",
|
247 |
+
"execution_count": null,
|
248 |
+
"id": "9a2d142f",
|
249 |
+
"metadata": {
|
250 |
+
"ExecuteTime": {
|
251 |
+
"end_time": "2021-12-09T16:47:04.838450Z",
|
252 |
+
"start_time": "2021-12-09T16:47:04.819726Z"
|
253 |
+
},
|
254 |
+
"pycharm": {
|
255 |
+
"name": "#%%\n"
|
256 |
+
}
|
257 |
+
},
|
258 |
+
"outputs": [],
|
259 |
+
"source": [
|
260 |
+
"print(len(data_df))\n",
|
261 |
+
"data_df.drop_duplicates(subset=[\"id\"], inplace=True)\n",
|
262 |
+
"print(len(data_df))"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": null,
|
268 |
+
"id": "45f4b970",
|
269 |
+
"metadata": {
|
270 |
+
"ExecuteTime": {
|
271 |
+
"end_time": "2021-12-09T16:47:04.971652Z",
|
272 |
+
"start_time": "2021-12-09T16:47:04.839618Z"
|
273 |
+
},
|
274 |
+
"pycharm": {
|
275 |
+
"name": "#%%\n"
|
276 |
+
}
|
277 |
+
},
|
278 |
+
"outputs": [],
|
279 |
+
"source": [
|
280 |
+
"data_df.to_csv(data_dir / \"data.csv\", index=False)"
|
281 |
+
]
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"cell_type": "code",
|
285 |
+
"execution_count": null,
|
286 |
+
"id": "f8019a02",
|
287 |
+
"metadata": {
|
288 |
+
"ExecuteTime": {
|
289 |
+
"end_time": "2021-12-09T16:47:05.104710Z",
|
290 |
+
"start_time": "2021-12-09T16:47:04.972681Z"
|
291 |
+
},
|
292 |
+
"pycharm": {
|
293 |
+
"name": "#%%\n"
|
294 |
+
}
|
295 |
+
},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"train_df, valid_df = train_test_split(data_df, test_size=0.1, shuffle=True)\n",
|
299 |
+
"train_df.to_csv(data_dir / \"train.csv\", index=False)\n",
|
300 |
+
"valid_df.to_csv(data_dir / \"valid.csv\", index=False)\n",
|
301 |
+
"print(len(train_df), len(valid_df))"
|
302 |
+
]
|
303 |
+
}
|
304 |
+
],
|
305 |
+
"metadata": {
|
306 |
+
"kernelspec": {
|
307 |
+
"display_name": "huggingface",
|
308 |
+
"language": "python",
|
309 |
+
"name": "huggingface"
|
310 |
+
},
|
311 |
+
"language_info": {
|
312 |
+
"codemirror_mode": {
|
313 |
+
"name": "ipython",
|
314 |
+
"version": 3
|
315 |
+
},
|
316 |
+
"file_extension": ".py",
|
317 |
+
"mimetype": "text/x-python",
|
318 |
+
"name": "python",
|
319 |
+
"nbconvert_exporter": "python",
|
320 |
+
"pygments_lexer": "ipython3",
|
321 |
+
"version": "3.9.7"
|
322 |
+
}
|
323 |
+
},
|
324 |
+
"nbformat": 4,
|
325 |
+
"nbformat_minor": 5
|
326 |
+
}
|
train/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
pandas==1.3.4
|
3 |
+
scikit-learn==1.0.1
|
4 |
+
python-box==5.4.1
|
5 |
+
transformers==4.12.5
|
6 |
+
torch==1.10.0+cu113
|
7 |
+
Pillow==8.4.0
|
train/train.ipynb
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "0fbed7bc",
|
7 |
+
"metadata": {
|
8 |
+
"ExecuteTime": {
|
9 |
+
"end_time": "2021-12-09T15:34:14.921553Z",
|
10 |
+
"start_time": "2021-12-09T15:34:14.911112Z"
|
11 |
+
}
|
12 |
+
},
|
13 |
+
"outputs": [],
|
14 |
+
"source": [
|
15 |
+
"%reload_ext autoreload\n",
|
16 |
+
"%autoreload 2"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": null,
|
22 |
+
"id": "c4b60ef3",
|
23 |
+
"metadata": {
|
24 |
+
"ExecuteTime": {
|
25 |
+
"end_time": "2021-12-09T15:34:15.961098Z",
|
26 |
+
"start_time": "2021-12-09T15:34:14.922771Z"
|
27 |
+
},
|
28 |
+
"code_folding": []
|
29 |
+
},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# imports\n",
|
33 |
+
"\n",
|
34 |
+
"import pandas as pd\n",
|
35 |
+
"import os\n",
|
36 |
+
"from pathlib import Path\n",
|
37 |
+
"from PIL import Image\n",
|
38 |
+
"import shutil\n",
|
39 |
+
"from logging import root\n",
|
40 |
+
"from PIL import Image\n",
|
41 |
+
"from pathlib import Path\n",
|
42 |
+
"import pandas as pd\n",
|
43 |
+
"import torch\n",
|
44 |
+
"from torch.utils.data import Dataset\n",
|
45 |
+
"from PIL import Image\n",
|
46 |
+
"from transformers import (\n",
|
47 |
+
" Seq2SeqTrainer,\n",
|
48 |
+
" Seq2SeqTrainingArguments,\n",
|
49 |
+
" get_linear_schedule_with_warmup,\n",
|
50 |
+
" AutoFeatureExtractor,\n",
|
51 |
+
" AutoTokenizer,\n",
|
52 |
+
" ViTFeatureExtractor,\n",
|
53 |
+
" VisionEncoderDecoderModel,\n",
|
54 |
+
" default_data_collator,\n",
|
55 |
+
")\n",
|
56 |
+
"from transformers.optimization import AdamW\n",
|
57 |
+
"\n",
|
58 |
+
"from box import Box\n",
|
59 |
+
"import inspect\n"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": null,
|
65 |
+
"id": "99d6f14d",
|
66 |
+
"metadata": {
|
67 |
+
"ExecuteTime": {
|
68 |
+
"end_time": "2021-12-09T15:34:15.979191Z",
|
69 |
+
"start_time": "2021-12-09T15:34:15.962078Z"
|
70 |
+
},
|
71 |
+
"code_folding": []
|
72 |
+
},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"# custom functions\n",
|
76 |
+
"\n",
|
77 |
+
"class ImageCaptionDataset(Dataset):\n",
|
78 |
+
" def __init__(\n",
|
79 |
+
" self, df, feature_extractor, tokenizer, images_dir, max_target_length=128\n",
|
80 |
+
" ):\n",
|
81 |
+
" self.df = df\n",
|
82 |
+
" self.feature_extractor = feature_extractor\n",
|
83 |
+
" self.tokenizer = tokenizer\n",
|
84 |
+
" self.images_dir = images_dir\n",
|
85 |
+
" self.max_target_length = max_target_length\n",
|
86 |
+
"\n",
|
87 |
+
" def __len__(self):\n",
|
88 |
+
" return len(self.df)\n",
|
89 |
+
"\n",
|
90 |
+
" def __getitem__(self, idx):\n",
|
91 |
+
" filename = self.df[\"filename\"][idx]\n",
|
92 |
+
" text = self.df[\"text\"][idx]\n",
|
93 |
+
" # prepare image (i.e. resize + normalize)\n",
|
94 |
+
" image = Image.open(self.images_dir / filename).convert(\"RGB\")\n",
|
95 |
+
" pixel_values = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n",
|
96 |
+
" # add labels (input_ids) by encoding the text\n",
|
97 |
+
" labels = self.tokenizer(\n",
|
98 |
+
" text,\n",
|
99 |
+
" padding=\"max_length\",\n",
|
100 |
+
" truncation=True,\n",
|
101 |
+
" max_length=self.max_target_length,\n",
|
102 |
+
" ).input_ids\n",
|
103 |
+
" # important: make sure that PAD tokens are ignored by the loss function\n",
|
104 |
+
" labels = [\n",
|
105 |
+
" label if label != self.tokenizer.pad_token_id else -100 for label in labels\n",
|
106 |
+
" ]\n",
|
107 |
+
"\n",
|
108 |
+
" encoding = {\n",
|
109 |
+
" \"pixel_values\": pixel_values.squeeze(),\n",
|
110 |
+
" \"labels\": torch.tensor(labels),\n",
|
111 |
+
" }\n",
|
112 |
+
" return encoding\n",
|
113 |
+
"\n",
|
114 |
+
"\n",
|
115 |
+
"\n",
|
116 |
+
"def predict(image, max_length=64, num_beams=4):\n",
|
117 |
+
"\n",
|
118 |
+
" pixel_values = feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n",
|
119 |
+
" pixel_values = pixel_values.to(device)\n",
|
120 |
+
"\n",
|
121 |
+
" with torch.no_grad():\n",
|
122 |
+
" output_ids = model.generate(\n",
|
123 |
+
" pixel_values,\n",
|
124 |
+
" max_length=max_length,\n",
|
125 |
+
" num_beams=num_beams,\n",
|
126 |
+
" return_dict_in_generate=True,\n",
|
127 |
+
" ).sequences\n",
|
128 |
+
"\n",
|
129 |
+
" preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n",
|
130 |
+
" preds = [pred.strip() for pred in preds]\n",
|
131 |
+
"\n",
|
132 |
+
" return preds\n"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": null,
|
138 |
+
"id": "ea66826b",
|
139 |
+
"metadata": {
|
140 |
+
"ExecuteTime": {
|
141 |
+
"end_time": "2021-12-09T15:34:16.042990Z",
|
142 |
+
"start_time": "2021-12-09T15:34:15.980557Z"
|
143 |
+
}
|
144 |
+
},
|
145 |
+
"outputs": [],
|
146 |
+
"source": [
|
147 |
+
"data_dir = Path(\"datasets\").resolve()\n",
|
148 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
149 |
+
"print(device)"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"execution_count": null,
|
155 |
+
"id": "17cfb2c2",
|
156 |
+
"metadata": {
|
157 |
+
"ExecuteTime": {
|
158 |
+
"end_time": "2021-12-09T15:34:16.058421Z",
|
159 |
+
"start_time": "2021-12-09T15:34:16.044111Z"
|
160 |
+
}
|
161 |
+
},
|
162 |
+
"outputs": [],
|
163 |
+
"source": [
|
164 |
+
"# arguments pertaining to what data we are going to input our model for training and eval.\n",
|
165 |
+
"\n",
|
166 |
+
"data_training_args = {\n",
|
167 |
+
" # The maximum total sequence length for target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.\n",
|
168 |
+
" \"max_target_length\": 64,\n",
|
169 |
+
"\n",
|
170 |
+
" # Number of beams to use for evaluation. This argument will be passed to model.generate which is used during evaluate and predict.\n",
|
171 |
+
" \"num_beams\": 4,\n",
|
172 |
+
"\n",
|
173 |
+
" # Folder with all the images\n",
|
174 |
+
" \"images_dir\": data_dir / \"images\",\n",
|
175 |
+
"}\n",
|
176 |
+
"\n",
|
177 |
+
"data_training_args = Box(data_training_args)"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "code",
|
182 |
+
"execution_count": null,
|
183 |
+
"id": "adc4839a",
|
184 |
+
"metadata": {
|
185 |
+
"ExecuteTime": {
|
186 |
+
"end_time": "2021-12-09T15:34:16.073242Z",
|
187 |
+
"start_time": "2021-12-09T15:34:16.059354Z"
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"outputs": [],
|
191 |
+
"source": [
|
192 |
+
"# arguments pertaining to which model/config/tokenizer we are going to fine-tune from.\n",
|
193 |
+
"\n",
|
194 |
+
"model_args = {\n",
|
195 |
+
"\n",
|
196 |
+
" # Path to pretrained model or model identifier from huggingface.co/models\"\n",
|
197 |
+
" \"encoder_model_name_or_path\": \"google/vit-base-patch16-224-in21k\",\n",
|
198 |
+
"\n",
|
199 |
+
" # Path to pretrained model or model identifier from huggingface.co/models\"\n",
|
200 |
+
" \"decoder_model_name_or_path\": \"gpt2\",\n",
|
201 |
+
"\n",
|
202 |
+
" # If set to int > 0, all ngrams of that size can only occur once.\n",
|
203 |
+
" \"no_repeat_ngram_size\": 3,\n",
|
204 |
+
"\n",
|
205 |
+
" # Exponential penalty to the length that will be used by default in the generate method of the model.\n",
|
206 |
+
" \"length_penalty\": 2.0,\n",
|
207 |
+
"}\n",
|
208 |
+
"\n",
|
209 |
+
"model_args = Box(model_args)"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": null,
|
215 |
+
"id": "22b8c9e3",
|
216 |
+
"metadata": {
|
217 |
+
"ExecuteTime": {
|
218 |
+
"end_time": "2021-12-09T15:34:16.089201Z",
|
219 |
+
"start_time": "2021-12-09T15:34:16.074223Z"
|
220 |
+
}
|
221 |
+
},
|
222 |
+
"outputs": [],
|
223 |
+
"source": [
|
224 |
+
"# arguments pertaining to Trainer class. Refer: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments\n",
|
225 |
+
"\n",
|
226 |
+
"training_args = {\n",
|
227 |
+
" \"num_train_epochs\": 5,\n",
|
228 |
+
" \"per_device_train_batch_size\": 32,\n",
|
229 |
+
" \"per_device_eval_batch_size\": 32,\n",
|
230 |
+
" \"output_dir\": \"output_dir\",\n",
|
231 |
+
" \"do_train\": True,\n",
|
232 |
+
" \"do_eval\": True,\n",
|
233 |
+
" \"fp16\": True,\n",
|
234 |
+
" \"learning_rate\": 1e-5,\n",
|
235 |
+
" \"load_best_model_at_end\": True,\n",
|
236 |
+
" \"evaluation_strategy\": \"epoch\",\n",
|
237 |
+
" \"save_strategy\": \"epoch\",\n",
|
238 |
+
" \"report_to\": \"none\"\n",
|
239 |
+
"}\n",
|
240 |
+
"\n",
|
241 |
+
"seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": null,
|
247 |
+
"id": "d0023eac",
|
248 |
+
"metadata": {
|
249 |
+
"ExecuteTime": {
|
250 |
+
"end_time": "2021-12-09T15:34:37.844396Z",
|
251 |
+
"start_time": "2021-12-09T15:34:16.090085Z"
|
252 |
+
}
|
253 |
+
},
|
254 |
+
"outputs": [],
|
255 |
+
"source": [
|
256 |
+
"feature_extractor = ViTFeatureExtractor.from_pretrained(\n",
|
257 |
+
" model_args.encoder_model_name_or_path\n",
|
258 |
+
")\n",
|
259 |
+
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
260 |
+
" model_args.decoder_model_name_or_path, use_fast=True\n",
|
261 |
+
")\n",
|
262 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
263 |
+
"\n",
|
264 |
+
"model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(\n",
|
265 |
+
" model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path\n",
|
266 |
+
")\n",
|
267 |
+
"\n",
|
268 |
+
"# set special tokens used for creating the decoder_input_ids from the labels\n",
|
269 |
+
"model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
|
270 |
+
"model.config.pad_token_id = tokenizer.pad_token_id\n",
|
271 |
+
"# make sure vocab size is set correctly\n",
|
272 |
+
"model.config.vocab_size = model.config.decoder.vocab_size\n",
|
273 |
+
"\n",
|
274 |
+
"# set beam search parameters\n",
|
275 |
+
"model.config.eos_token_id = tokenizer.sep_token_id\n",
|
276 |
+
"model.config.max_length = data_training_args.max_target_length\n",
|
277 |
+
"model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size\n",
|
278 |
+
"model.config.length_penalty = model_args.length_penalty\n",
|
279 |
+
"model.config.num_beams = data_training_args.num_beams\n",
|
280 |
+
"model.decoder.resize_token_embeddings(len(tokenizer))\n"
|
281 |
+
]
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"cell_type": "code",
|
285 |
+
"execution_count": null,
|
286 |
+
"id": "6428ea08",
|
287 |
+
"metadata": {
|
288 |
+
"ExecuteTime": {
|
289 |
+
"end_time": "2021-12-09T15:34:37.933804Z",
|
290 |
+
"start_time": "2021-12-09T15:34:37.845607Z"
|
291 |
+
}
|
292 |
+
},
|
293 |
+
"outputs": [],
|
294 |
+
"source": [
|
295 |
+
"train_df = pd.read_csv(data_dir / \"train.csv\")\n",
|
296 |
+
"valid_df = pd.read_csv(data_dir / \"valid.csv\")\n",
|
297 |
+
"\n",
|
298 |
+
"train_dataset = ImageCaptionDataset(\n",
|
299 |
+
" df=train_df,\n",
|
300 |
+
" feature_extractor=feature_extractor,\n",
|
301 |
+
" tokenizer=tokenizer,\n",
|
302 |
+
" images_dir=data_training_args.images_dir,\n",
|
303 |
+
" max_target_length=data_training_args.max_target_length,\n",
|
304 |
+
")\n",
|
305 |
+
"eval_dataset = ImageCaptionDataset(\n",
|
306 |
+
" df=valid_df,\n",
|
307 |
+
" feature_extractor=feature_extractor,\n",
|
308 |
+
" tokenizer=tokenizer,\n",
|
309 |
+
" images_dir=data_training_args.images_dir,\n",
|
310 |
+
" max_target_length=data_training_args.max_target_length,\n",
|
311 |
+
")\n",
|
312 |
+
"\n",
|
313 |
+
"print(f\"Number of training examples: {len(train_dataset)}\")\n",
|
314 |
+
"print(f\"Number of validation examples: {len(eval_dataset)}\")"
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "code",
|
319 |
+
"execution_count": null,
|
320 |
+
"id": "c8e492a1",
|
321 |
+
"metadata": {
|
322 |
+
"ExecuteTime": {
|
323 |
+
"end_time": "2021-12-09T15:34:37.971630Z",
|
324 |
+
"start_time": "2021-12-09T15:34:37.935339Z"
|
325 |
+
}
|
326 |
+
},
|
327 |
+
"outputs": [],
|
328 |
+
"source": [
|
329 |
+
"# Let's verify an example from the training dataset:\n",
|
330 |
+
"\n",
|
331 |
+
"encoding = train_dataset[0]\n",
|
332 |
+
"for k,v in encoding.items():\n",
|
333 |
+
" print(k, v.shape)"
|
334 |
+
]
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": null,
|
339 |
+
"id": "edb4e7a6",
|
340 |
+
"metadata": {
|
341 |
+
"ExecuteTime": {
|
342 |
+
"end_time": "2021-12-09T15:34:38.006980Z",
|
343 |
+
"start_time": "2021-12-09T15:34:37.972483Z"
|
344 |
+
}
|
345 |
+
},
|
346 |
+
"outputs": [],
|
347 |
+
"source": [
|
348 |
+
"# We can also check the original image and decode the labels:\n",
|
349 |
+
"image = Image.open(data_training_args.images_dir / train_df[\"filename\"][0]).convert(\"RGB\")\n",
|
350 |
+
"image"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"execution_count": null,
|
356 |
+
"id": "25f2cae7",
|
357 |
+
"metadata": {
|
358 |
+
"ExecuteTime": {
|
359 |
+
"end_time": "2021-12-09T15:34:38.031745Z",
|
360 |
+
"start_time": "2021-12-09T15:34:38.008027Z"
|
361 |
+
}
|
362 |
+
},
|
363 |
+
"outputs": [],
|
364 |
+
"source": [
|
365 |
+
"labels = encoding[\"labels\"]\n",
|
366 |
+
"labels[labels == -100] = tokenizer.pad_token_id\n",
|
367 |
+
"label_str = tokenizer.decode(labels, skip_special_tokens=True)\n",
|
368 |
+
"print(label_str)\n"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"cell_type": "code",
|
373 |
+
"execution_count": null,
|
374 |
+
"id": "b7a009d3",
|
375 |
+
"metadata": {
|
376 |
+
"ExecuteTime": {
|
377 |
+
"end_time": "2021-12-09T15:34:38.049539Z",
|
378 |
+
"start_time": "2021-12-09T15:34:38.032749Z"
|
379 |
+
}
|
380 |
+
},
|
381 |
+
"outputs": [],
|
382 |
+
"source": [
|
383 |
+
"optimizer = AdamW(model.parameters(), lr=seq2seq_training_args.learning_rate)\n",
|
384 |
+
"\n",
|
385 |
+
"steps_per_epoch = len(train_dataset) // seq2seq_training_args.per_device_train_batch_size\n",
|
386 |
+
"num_training_steps = steps_per_epoch * seq2seq_training_args.num_train_epochs\n",
|
387 |
+
"\n",
|
388 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
389 |
+
" optimizer,\n",
|
390 |
+
" num_warmup_steps=seq2seq_training_args.warmup_steps,\n",
|
391 |
+
" num_training_steps=num_training_steps,\n",
|
392 |
+
")\n",
|
393 |
+
"\n",
|
394 |
+
"optimizers = (optimizer, lr_scheduler)"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"execution_count": null,
|
400 |
+
"id": "f2f477b2",
|
401 |
+
"metadata": {
|
402 |
+
"ExecuteTime": {
|
403 |
+
"start_time": "2021-12-09T15:34:14.944Z"
|
404 |
+
}
|
405 |
+
},
|
406 |
+
"outputs": [],
|
407 |
+
"source": [
|
408 |
+
"trainer = Seq2SeqTrainer(\n",
|
409 |
+
" model=model,\n",
|
410 |
+
" optimizers=optimizers,\n",
|
411 |
+
" tokenizer=feature_extractor,\n",
|
412 |
+
" args=seq2seq_training_args,\n",
|
413 |
+
" train_dataset=train_dataset,\n",
|
414 |
+
" eval_dataset=eval_dataset,\n",
|
415 |
+
" data_collator=default_data_collator,\n",
|
416 |
+
")\n",
|
417 |
+
"\n",
|
418 |
+
"trainer.train()"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": null,
|
424 |
+
"id": "f08d2b7c",
|
425 |
+
"metadata": {
|
426 |
+
"ExecuteTime": {
|
427 |
+
"end_time": "2021-12-09T16:24:49.096274Z",
|
428 |
+
"start_time": "2021-12-09T16:24:49.096246Z"
|
429 |
+
}
|
430 |
+
},
|
431 |
+
"outputs": [],
|
432 |
+
"source": [
|
433 |
+
"test_img = \"../examples/tt7991608-red-notice.jpg\"\n",
|
434 |
+
"with Image.open(test_img) as image:\n",
|
435 |
+
" preds = predict(\n",
|
436 |
+
" image, max_length=data_training_args.max_target_length, num_beams=data_training_args.num_beams\n",
|
437 |
+
" )\n",
|
438 |
+
"\n",
|
439 |
+
"# Uncomment to display the test image in a jupyter notebook\n",
|
440 |
+
"# display(image)\n",
|
441 |
+
"print(preds[0])"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": null,
|
447 |
+
"id": "ecf21225",
|
448 |
+
"metadata": {},
|
449 |
+
"outputs": [],
|
450 |
+
"source": []
|
451 |
+
}
|
452 |
+
],
|
453 |
+
"metadata": {
|
454 |
+
"kernelspec": {
|
455 |
+
"display_name": "huggingface",
|
456 |
+
"language": "python",
|
457 |
+
"name": "huggingface"
|
458 |
+
},
|
459 |
+
"language_info": {
|
460 |
+
"codemirror_mode": {
|
461 |
+
"name": "ipython",
|
462 |
+
"version": 3
|
463 |
+
},
|
464 |
+
"file_extension": ".py",
|
465 |
+
"mimetype": "text/x-python",
|
466 |
+
"name": "python",
|
467 |
+
"nbconvert_exporter": "python",
|
468 |
+
"pygments_lexer": "ipython3",
|
469 |
+
"version": "3.9.7"
|
470 |
+
}
|
471 |
+
},
|
472 |
+
"nbformat": 4,
|
473 |
+
"nbformat_minor": 5
|
474 |
+
}
|