iakarshu commited on
Commit
f8ffe75
·
1 Parent(s): d7f6f38

Upload dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +127 -0
dataset.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile dataset.py
2
+ import pandas as pd
3
+ import os
4
+ from PIL import Image, ImageDraw
5
+ import numpy as np
6
+ import pytesseract
7
+ import torch
8
+
9
+ ## I guess, I got my own script for it, from https://github.com/shabie/docformer/blob/master/src/docformer/dataset.py
10
+
11
+ def rescale_bbox(bbox, img_width : int,
12
+ img_height : int, size : int = 1000):
13
+ x0, x1, y0, y1, width, height = bbox
14
+ x0 = int(size * (x0 / img_width))
15
+ x1 = int(size * (x1 / img_width))
16
+ y0 = int(size * (y0 / img_height))
17
+ y1 = int(size * (y1 / img_height))
18
+ width = int(size * (width / img_width))
19
+ height = int(size * (height / img_height))
20
+ return [x0, x1, y0, y1, width, height]
21
+
22
+ def coordinate_features(df_row):
23
+ xmin, ymin, width, height = df_row["left"], df_row["top"], df_row["width"], df_row["height"]
24
+ return [xmin, xmin + width, ymin, ymin + height, width, height] ## [xmin, xmax, ymin, ymax, width, height]
25
+
26
+ def get_ocr_results(image_path : str):
27
+
28
+ """
29
+ Returns words and its bounding boxes from the image file path
30
+ image_path: string containing the path of the image
31
+ """
32
+
33
+ ## Getting the Image
34
+ image = Image.open(image_path)
35
+ width, height = image.size
36
+
37
+ ## OCR Processing
38
+ ocr_df = pytesseract.image_to_data(image, output_type="data.frame")
39
+ ocr_df = ocr_df.dropna().reset_index(drop=True)
40
+ float_cols = ocr_df.select_dtypes("float").columns
41
+ ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
42
+ ocr_df = ocr_df.replace(r"^\s*$", np.nan, regex=True)
43
+ ocr_df = ocr_df.dropna().reset_index(drop=True)
44
+ ocr_df = ocr_df.sort_values(by=['left', 'top']) ## Sorting the values on the basis of left, top bounding box coordinates
45
+
46
+ ## Finally getting the words and the bounding box
47
+ words = list(ocr_df.text.apply(lambda x: str(x).strip()))
48
+ actual_bboxes = ocr_df.apply(coordinate_features, axis=1).values.tolist()
49
+
50
+ # add as extra columns
51
+ assert len(words) == len(actual_bboxes)
52
+ return {"words": words, "bbox": actual_bboxes}
53
+
54
+ ## Stealed from here: https://github.com/uakarsh/latr/blob/main/src/latr/dataset.py
55
+
56
+ def get_tokens_with_boxes(unnormalized_word_boxes, list_of_words,
57
+ tokenizer, pad_token_id : int = 0,
58
+ pad_token_box = [0, 0, 0, 0, 0, 0],
59
+ max_seq_len = 512,
60
+ sep_token_box = [0, 0, 1000, 1000, 0, 0]
61
+ ):
62
+
63
+ '''
64
+ This function returns two items:
65
+ 1. unnormalized_token_boxes -> a list of len = max_seq_len, containing the boxes corresponding to the tokenized words,
66
+ one box might repeat as per the tokenization procedure
67
+ 2. tokenized_words -> tokenized words corresponding to the tokenizer and the list_of_words
68
+ '''
69
+
70
+ assert len(unnormalized_word_boxes) == len(list_of_words), "Bounding box length != total words length"
71
+
72
+ length_of_box = len(unnormalized_word_boxes)
73
+ unnormalized_token_boxes = []
74
+ tokenized_words = []
75
+
76
+ ## CLS, SEP tokens have to be appended
77
+ unnormalized_token_boxes.extend([pad_token_box])
78
+ tokenized_words.extend([tokenizer.cls_token_id]) ## CLS Token Box is same as pad_token_box, if not, you can change here
79
+
80
+ ## Normal for loop
81
+ idx = 0
82
+ for box, word in zip(unnormalized_word_boxes, list_of_words):
83
+ if idx != 0:
84
+ new_word = " " + word
85
+ else:
86
+ new_word = word
87
+ current_tokens = tokenizer(new_word, add_special_tokens = False).input_ids
88
+ unnormalized_token_boxes.extend([box]*len(current_tokens))
89
+ tokenized_words.extend(current_tokens)
90
+ idx += 1
91
+
92
+ ## For post processing the token box
93
+ if len(unnormalized_token_boxes)<max_seq_len:
94
+ unnormalized_token_boxes.extend([sep_token_box])
95
+ unnormalized_token_boxes.extend([pad_token_box] * (max_seq_len-len(unnormalized_token_boxes)))
96
+
97
+ else:
98
+ unnormalized_token_boxes[max_seq_len - 1] = sep_token_box
99
+
100
+ ## For post processing the tokenized words
101
+ if len(tokenized_words) < max_seq_len:
102
+ tokenized_words.extend([tokenizer.sep_token_id])
103
+ tokenized_words.extend([pad_token_id]* (max_seq_len-len(tokenized_words)))
104
+
105
+ else:
106
+ tokenized_words[max_seq_len - 1] = tokenizer.sep_token_id
107
+
108
+ return unnormalized_token_boxes[:max_seq_len], tokenized_words[:max_seq_len]
109
+
110
+
111
+ def create_features(
112
+ img_path : str,
113
+ tokenizer,
114
+ max_seq_length : int = 512,
115
+ size : int = 1000,
116
+ use_ocr : bool = True,
117
+ bounding_box = None,
118
+ words = None ):
119
+
120
+ image = Image.open(img_path).convert("RGB")
121
+ ocr_results = get_ocr_results(img_path)
122
+ ocr_results['rescale_bbox'] = list(map(lambda x: rescale_bbox(x, *image.size, size = size), ocr_results['bbox']))
123
+ boxes, words = get_tokens_with_boxes(ocr_results['rescale_bbox'], ocr_results['words'], tokenizer)
124
+ torch_boxes = torch.as_tensor(boxes)
125
+ torch_words = torch.as_tensor(words)
126
+
127
+ return torch_boxes, torch_words, ocr_results['bbox']