File size: 9,069 Bytes
6df3c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee7e9d0
 
6df3c38
 
ee7e9d0
 
 
 
6df3c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from .utils import load_model,load_processor,normalize_box,compare_boxes,adjacent
from .annotate_image import get_flattened_output,annotate_image
from PIL import Image,ImageDraw, ImageFont
import logging
import torch
import json
import os


logger = logging.getLogger(__name__)


class ModelHandler(object):
    """Serve LayoutLMv3 token-classification predictions for document images.

    The request pipeline is initialize -> preprocess -> inference ->
    postprocess, driven by :meth:`handle`.  Each stage logs failures, stores
    the message in ``self.error`` and returns ``None`` instead of raising, so
    the caller checks return values rather than catching exceptions.
    """

    def __init__(self):
        self.model = None
        self.model_dir = None
        self.device = 'cpu'
        self.error = None              # message of the most recent failure
        self.initialized = False       # True only once a model was loaded
        self._context = None           # context passed to initialize()
        self._raw_input_data = None    # last request dict seen by preprocess()
        self._processed_data = None    # processor encoding for that request
        self._images_size = None       # (width, height) of each input image
        self._processor = None         # created lazily on first preprocess()
        self.counter = 0

    def initialize(self, context):
        """Load the model whose directory is given by ``context``.

        :param context: mapping read with ``context.get("model_dir")``; it is
            also kept on ``self._context``.

        ``initialized`` becomes ``True`` only when a model object was actually
        returned. ``load()`` reports failure by returning ``None``, and marking
        the handler ready anyway would make every later request run
        inference on ``None`` without ever retrying the load.
        """
        try:
            logger.info("Loading transformer model")
            self._context = context
            self.model_dir = context.get("model_dir")
            self.model = self.load(self.model_dir)
            self.initialized = self.model is not None
        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            self.error = str(e)

    def preprocess(self, batch):
        """Open the request images and encode words and boxes for the model.

        :param batch: dict with ``'image_path'`` (one path per page),
            ``'words'`` (one word list per page) and ``'bboxes'`` (one box list
            per page, parallel to ``'words'``). The boxes are presumably in the
            image's pixel coordinates, since they are normalized with that
            image's width/height (TODO confirm against ``normalize_box``).
        :returns: the processor output (``return_tensors="pt"``), or ``None``
            on failure.
        """
        try:
            self._raw_input_data = batch
            # The processor does not depend on the request; build it once.
            if self._processor is None:
                self._processor = load_processor()

            images = []
            for path in batch['image_path']:
                # convert() returns an independent, fully loaded image, so the
                # underlying file can be closed immediately.
                with Image.open(path) as source:
                    images.append(source.convert("RGB"))
            self._images_size = [img.size for img in images]

            boxes = [
                [normalize_box(box, *self._images_size[i]) for box in doc]
                for i, doc in enumerate(batch['bboxes'])
            ]
            encoded_inputs = self._processor(
                images, batch['words'], boxes=boxes, return_tensors="pt",
                padding="max_length", truncation=True)
            self._processed_data = encoded_inputs
            return encoded_inputs
        except Exception as e:
            logger.exception(f"Error in preprocessing: {str(e)}")
            self.error = str(e)
            return None

    def load(self, model_dir):
        """Return the model loaded from ``model_dir``, or ``None`` on failure.

        On failure the error is logged and stored in ``self.error``.
        """
        try:
            return load_model(model_dir)
        except Exception as e:
            logger.exception(f"Error loading LayoutLMv3 model: {str(e)}")
            self.error = str(e)
            return None

    def inference(self, model_input):
        """Run the model and take the arg-max label id for every token.

        :param model_input: keyword arguments for the model (the output of
            :meth:`preprocess`).
        :returns: ``[[{'output_0': ids_0}, {'output_1': ids_1}, ...]]``, with one
            dict per batch item inside a single-element outer list, or
            ``None`` on failure.
        """
        try:
            with torch.no_grad():
                outputs = self.model(**model_input)
            predictions = outputs.logits.argmax(-1).tolist()
            return [[{f'output_{i}': ids} for i, ids in enumerate(predictions)]]
        except Exception as e:
            logger.exception(f"Error in inference: {str(e)}")
            self.error = str(e)
            return None

    def postprocess(self, inference_output):
        """Turn per-token predictions into labelled word spans for each page.

        Each request word is compared, via ``compare_boxes``, with every token
        box of its page.

        - The word's label is the label of the first matching token. The
          first two characters are stripped, and ``'O'`` maps to ``'other'``.
        - If that first label is ``'other'``, the last matching token's label
          is used instead.
        - Words with no match are ``'other'``.

        ``'other'`` words are dropped, and the remaining words are grouped
        into spans (see :meth:`_group_spans`).

        :param inference_output: value returned by :meth:`inference`.
        :returns: single-element list holding a JSON string of
            ``[{'output': [span, ...]}, ...]`` (one entry per page), or
            ``None`` on failure.
        """
        try:
            id2label = self.model.config.id2label
            # Convert the tensor once, not once per word.
            token_boxes = self._processed_data['bbox'].tolist()
            docs = []
            next_id = 0  # word ids run continuously across all pages
            for page, doc_words in enumerate(self._raw_input_data['words']):
                width, height = self._images_size[page]
                page_boxes = token_boxes[page]
                page_predictions = inference_output[0][page][f'output_{page}']
                doc_list = []
                for i, doc_word in enumerate(doc_words):
                    raw_box = self._raw_input_data['bboxes'][page][i]
                    word = {
                        'id': next_id,
                        'text': doc_word,
                        'pageNum': page + 1,
                        'box': raw_box,
                    }
                    next_id += 1
                    word['label'] = self._word_label(
                        normalize_box(raw_box, width, height),
                        page_boxes, page_predictions, id2label)
                    word['pageSize'] = {'width': width, 'height': height}
                    if word['label'] != 'other':
                        doc_list.append(word)
                spans = self._group_spans(doc_list)
                docs.append({'output': [self._format_span(span) for span in spans]})
            return [json.dumps(docs, ensure_ascii=False)]
        except Exception as e:
            logger.exception(f"Error in postprocessing: {str(e)}")
            self.error = str(e)
            return None

    @staticmethod
    def _word_label(word_box, token_boxes, token_label_ids, id2label):
        """Return a word's label from the tokens whose box matches ``word_box``."""
        labels = []
        for token_box, label_id in zip(token_boxes, token_label_ids):
            if compare_boxes(token_box, word_box):
                label = id2label[label_id]
                labels.append(label[2:] if label != 'O' else 'other')
        if not labels:
            return 'other'
        return labels[0] if labels[0] != 'other' else labels[-1]

    @staticmethod
    def _group_spans(entities):
        """Group labelled words into spans of consecutive adjacent words.

        Words for which ``adjacent`` is false against every word in
        ``entities`` become single-word spans first. The rest are then walked
        in order. Each new span is extended while the next word is adjacent
        to the span's last word.
        """
        spans = []
        remaining = []
        for entity in entities:
            if any(adjacent(entity, other) for other in entities):
                remaining.append(entity)
            else:
                spans.append([entity])

        # An index walk replaces repeated list slicing/removal (O(n^2)).
        index = 0
        while index < len(remaining):
            span = [remaining[index]]
            index += 1
            while index < len(remaining) and adjacent(span[-1], remaining[index]):
                span.append(remaining[index])
                index += 1
            spans.append(span)
        return spans

    @staticmethod
    def _format_span(span):
        """Serialize one span; ITEMS and PRICE span texts get a trailing '|'."""
        label = span[0]['label']
        text = ' '.join(entity['text'] for entity in span)
        if label in ('ITEMS', 'PRICE'):
            text += '|'
        return {
            "text": text,
            "label": label,
            "words": [
                {'id': entity['id'], 'box': entity['box'], 'text': entity['text']}
                for entity in span
            ],
        }

    def handle(self, data, context):
        """
        Run the full pipeline for one request and write its artifacts.

        Each input image is passed to ``annotate_image`` with its flattened
        output. Each page's result is written to
        ``static/temp/labeled/Output_<index>.json``.

        :param data: request dict with ``'image_path'``, ``'words'`` and
            ``'bboxes'`` (see :meth:`preprocess`), or ``None``.
        :param context: mms context, forwarded to :meth:`initialize`.
        :returns: ``None`` in every case. Failures are logged and stored in
            ``self.error``.
        """
        try:
            if not self.initialized:
                self.initialize(context)
            if not self.initialized:
                # The model could not be loaded; initialize() recorded why.
                return None

            if data is None:
                return None

            model_input = self.preprocess(data)
            if model_input is None:
                return None

            model_out = self.inference(model_input)
            if model_out is None:
                return None

            # postprocess() returns None on failure; indexing it blindly
            # would replace the recorded error with a TypeError.
            postprocessed = self.postprocess(model_out)
            if postprocessed is None:
                return None

            # Parse the JSON string produced by postprocess().
            inference_out_list = json.loads(postprocessed[0])
            flattened_output_list = get_flattened_output(inference_out_list)
            logger.info('Ready for Annotation')
            for i, flattened_output in enumerate(flattened_output_list):
                annotate_image(data['image_path'][i], flattened_output)

            labeled_dir = 'static/temp/labeled'
            os.makedirs(labeled_dir, exist_ok=True)

            # One JSON file per page, named by the page's position.
            for index, output_json in enumerate(inference_out_list):
                output_filename = os.path.join(labeled_dir, f'Output_{index}.json')
                with open(output_filename, 'w', encoding='utf-8') as output_file:
                    json.dump(output_json, output_file)
        except Exception as e:
            logger.exception(f"Error handling request: {str(e)}")
            self.error = str(e)

                        
# Single handler instance shared by every call to the module-level handle().
_service = ModelHandler()


def handle(data, context):
    """Module-level entry point; delegates to the shared ``ModelHandler``.

    ``ModelHandler.handle`` already performs lazy initialization and the
    ``data is None`` check. Repeating them here made a failed model load be
    attempted, and logged, twice per request.

    :param data: request dict (see ``ModelHandler.preprocess``) or ``None``.
    :param context: mms context forwarded to ``ModelHandler.initialize``.
    :returns: whatever ``ModelHandler.handle`` returns.
    """
    return _service.handle(data, context)