put online demo

Files changed:
- .gitignore +13 -0
- OCR.py +415 -0
- demo_streamlit.py +339 -0
- display.py +181 -0
- eval.py +649 -0
- flask.py +6 -0
- htlm_webpage.py +141 -0
- packages.txt +1 -0
- requirements.txt +10 -0
- toXML.py +351 -0
- train.py +394 -0
- utils.py +936 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
+
+__pycache__/
+
+temp/
+
+
+VISION_KEY.json
+
+*.pth
+
+.streamlit/secrets.toml
+
+backup/
OCR.py
ADDED
@@ -0,0 +1,415 @@
+
+import os
+from azure.ai.vision.imageanalysis import ImageAnalysisClient
+from azure.ai.vision.imageanalysis.models import VisualFeatures
+from azure.core.credentials import AzureKeyCredential
+import time
+import numpy as np
+import networkx as nx
+from eval import iou
+from utils import class_dict, proportion_inside
+import json
+from utils import rescale_boxes as rescale
+import streamlit as st
+
+VISION_KEY = st.secrets["VISION_KEY"]
+VISION_ENDPOINT = st.secrets["VISION_ENDPOINT"]
+
+"""
+# If local execution
+with open("VISION_KEY.json", "r") as json_file:
+    json_data = json.load(json_file)
+
+# Step 2: Parse the JSON data (this is done by json.load automatically)
+VISION_KEY = json_data["VISION_KEY"]
+VISION_ENDPOINT = json_data["VISION_ENDPOINT"]
+"""
+
+
+def sample_ocr_image_file(image_data):
+    # Set the values of your computer vision endpoint and computer vision key
+    # as environment variables:
+    try:
+        endpoint = VISION_ENDPOINT
+        key = VISION_KEY
+    except KeyError:
+        print("Missing environment variable 'VISION_ENDPOINT' or 'VISION_KEY'")
+        print("Set them before running this sample.")
+        exit()
+
+    # Create an Image Analysis client
+    client = ImageAnalysisClient(
+        endpoint=endpoint,
+        credential=AzureKeyCredential(key)
+    )
+
+    # Extract text (OCR) from an image stream. This will be a synchronous (blocking) call.
+    result = client.analyze(
+        image_data=image_data,
+        visual_features=[VisualFeatures.READ]
+    )
+
+    return result
+
+
+def text_prediction(image):
+    # Transform the image into a byte array
+    image.save('temp.jpg')
+    with open('temp.jpg', 'rb') as f:
+        image_data = f.read()
+    ocr_result = sample_ocr_image_file(image_data)
+    # Delete the temporary image
+    os.remove('temp.jpg')
+    return ocr_result
+
+def filter_text(ocr_result, threshold=0.5):
+    words_to_cancel = {"+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
+    # Add every other one-letter word to the list of words to cancel, except 'I' and 'a'
+    for letter in "bcdefghjklmnopqrstuvwxyz1234567890":  # All lowercase letters except 'a'
+        words_to_cancel.add(letter)
+        words_to_cancel.add("i")
+        words_to_cancel.add(letter.upper())  # Add the uppercase version as well
+    characters_to_cancel = {"+", "<", ">"}  # Characters to cancel
+
+    list_of_lines = []
+
+    for block in ocr_result['readResult']['blocks']:
+        for line in block['lines']:
+            line_text = []
+            x_min, y_min = float('inf'), float('inf')
+            x_max, y_max = float('-inf'), float('-inf')
+            for word in line['words']:
+                if word['text'] in words_to_cancel or any(disallowed_char in word['text'] for disallowed_char in characters_to_cancel):
+                    continue
+                if word['confidence'] > threshold:
+                    if word['text']:
+                        line_text.append(word['text'])
+                        x = [point['x'] for point in word['boundingPolygon']]
+                        y = [point['y'] for point in word['boundingPolygon']]
+                        x_min = min(x_min, min(x))
+                        y_min = min(y_min, min(y))
+                        x_max = max(x_max, max(x))
+                        y_max = max(y_max, max(y))
+            if line_text:  # If there are valid words in the line
+                list_of_lines.append({
+                    'text': ' '.join(line_text),
+                    'boundingBox': [x_min, y_min, x_max, y_max]
+                })
+
+    list_text = []
+    list_bbox = []
+    for i in range(len(list_of_lines)):
+        list_text.append(list_of_lines[i]['text'])
+    for i in range(len(list_of_lines)):
+        list_bbox.append(list_of_lines[i]['boundingBox'])
+
+    list_of_lines = [list_bbox, list_text]
+
+    return list_of_lines
+
+
+
+
+def get_box_points(box):
+    """Returns all critical points of a box: corners and midpoints of edges."""
+    xmin, ymin, xmax, ymax = box
+    return np.array([
+        [xmin, ymin],              # Bottom-left corner
+        [xmax, ymin],              # Bottom-right corner
+        [xmin, ymax],              # Top-left corner
+        [xmax, ymax],              # Top-right corner
+        [(xmin + xmax) / 2, ymin], # Midpoint of bottom edge
+        [(xmin + xmax) / 2, ymax], # Midpoint of top edge
+        [xmin, (ymin + ymax) / 2], # Midpoint of left edge
+        [xmax, (ymin + ymax) / 2]  # Midpoint of right edge
+    ])
+
+def min_distance_between_boxes(box1, box2):
+    """Computes the minimum distance between two boxes considering all critical points."""
+    points1 = get_box_points(box1)
+    points2 = get_box_points(box2)
+
+    min_dist = float('inf')
+    for point1 in points1:
+        for point2 in points2:
+            dist = np.linalg.norm(point1 - point2)
+            if dist < min_dist:
+                min_dist = dist
+    return min_dist
+
+
+def is_inside(box1, box2):
+    """Check if the center of box1 is inside box2."""
+    x_center = (box1[0] + box1[2]) / 2
+    y_center = (box1[1] + box1[3]) / 2
+    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
+
+def are_close(box1, box2, threshold=50):
+    """Determines if boxes are close based on their corners and center points."""
+    corners1 = np.array([
+        [box1[0], box1[1]], [box1[0], box1[3]], [box1[2], box1[1]], [box1[2], box1[3]],
+        [(box1[0]+box1[2])/2, box1[1]], [(box1[0]+box1[2])/2, box1[3]],
+        [box1[0], (box1[1]+box1[3])/2], [box1[2], (box1[1]+box1[3])/2]
+    ])
+    corners2 = np.array([
+        [box2[0], box2[1]], [box2[0], box2[3]], [box2[2], box2[1]], [box2[2], box2[3]],
+        [(box2[0]+box2[2])/2, box2[1]], [(box2[0]+box2[2])/2, box2[3]],
+        [box2[0], (box2[1]+box2[3])/2], [box2[2], (box2[1]+box2[3])/2]
+    ])
+    for c1 in corners1:
+        for c2 in corners2:
+            if np.linalg.norm(c1 - c2) < threshold:
+                return True
+    return False
+
+def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
+    """Find the closest box to the given text box within a specified threshold."""
+    min_distance = float('inf')
+    closest_index = None
+
+    # Check if the text is inside a sequenceFlow
+    for j in range(len(all_boxes)):
+        if proportion_inside(text_box, all_boxes[j]) > iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
+            return j
+
+    for i, box in enumerate(all_boxes):
+        # Compute the center of both boxes
+        center_text = np.array([(text_box[0] + text_box[2]) / 2, (text_box[1] + text_box[3]) / 2])
+        center_box = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])
+
+        # Calculate Euclidean distance between centers
+        distance = np.linalg.norm(center_text - center_box)
+
+        # Update closest box if this box is nearer
+        if distance < min_distance:
+            min_distance = distance
+            closest_index = i
+
+    # Check if the closest box found is within the acceptable threshold
+    if min_distance < threshold:
+        return closest_index
+
+    return None
+
+
+def is_vertical(box):
+    """Determine if the text in the bounding box is vertically aligned."""
+    width = box[2] - box[0]
+    height = box[3] - box[1]
+    return (height > 2*width)
+
+def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
+    """Maps text boxes to task boxes and groups texts within each task based on proximity."""
+    G = nx.Graph()
+
+    # Map each text box to the nearest task box
+    task_to_texts = {i: [] for i in range(len(task_boxes))}
+    information_texts = []  # Texts not inside any task box
+    text_to_task_mapped = [False] * len(text_boxes)
+
+    for idx, text_box in enumerate(text_boxes):
+        mapped = False
+        for jdx, task_box in enumerate(task_boxes):
+            if proportion_inside(text_box, task_box) > iou_threshold:
+                task_to_texts[jdx].append(idx)
+                text_to_task_mapped[idx] = True
+                mapped = True
+                break
+        if not mapped:
+            information_texts.append(idx)
+
+    all_grouped_texts = []
+    sentence_boxes = []  # Store the bounding box for each sentence
+
+    # Process texts for each task
+    for task_texts in task_to_texts.values():
+        G.clear()
+        for i in task_texts:
+            G.add_node(i)
+            for j in task_texts:
+                if i != j and are_close(text_boxes[i], text_boxes[j]) and not is_vertical(text_boxes[i]) and not is_vertical(text_boxes[j]):
+                    G.add_edge(i, j)
+
+        groups = list(nx.connected_components(G))
+        for group in groups:
+            group = list(group)
+            lines = {}
+            for idx in group:
+                y_center = (text_boxes[idx][1] + text_boxes[idx][3]) / 2
+                found_line = False
+                for line in lines:
+                    if abs(y_center - line) < (text_boxes[idx][3] - text_boxes[idx][1]) / 2:
+                        lines[line].append(idx)
+                        found_line = True
+                        break
+                if not found_line:
+                    lines[y_center] = [idx]
+
+            sorted_lines = sorted(lines.keys())
+            grouped_texts = []
+            min_x = min_y = float('inf')
+            max_x = max_y = -float('inf')
+
+            for line in sorted_lines:
+                sorted_indices = sorted(lines[line], key=lambda idx: text_boxes[idx][0])
+                line_text = ' '.join(texts[idx] for idx in sorted_indices)
+                grouped_texts.append(line_text)
+
+                for idx in sorted_indices:
+                    box = text_boxes[idx]
+                    min_x = min(min_x-5, box[0]-5)
+                    min_y = min(min_y-5, box[1]-5)
+                    max_x = max(max_x+5, box[2]+5)
+                    max_y = max(max_y+5, box[3]+5)
+
+            all_grouped_texts.append(' '.join(grouped_texts))
+            sentence_boxes.append([min_x, min_y, max_x, max_y])
+
+    # Group information texts
+    G.clear()
+    info_sentence_boxes = []
+
+    for i in information_texts:
+        G.add_node(i)
+        for j in information_texts:
+            if i != j and are_close(text_boxes[i], text_boxes[j], percentage_thresh * min_dist) and not is_vertical(text_boxes[i]) and not is_vertical(text_boxes[j]):
+                G.add_edge(i, j)
+
+    info_groups = list(nx.connected_components(G))
+    information_grouped_texts = []
+    for group in info_groups:
+        group = list(group)
+        lines = {}
+        for idx in group:
+            y_center = (text_boxes[idx][1] + text_boxes[idx][3]) / 2
+            found_line = False
+            for line in lines:
+                if abs(y_center - line) < (text_boxes[idx][3] - text_boxes[idx][1]) / 2:
+                    lines[line].append(idx)
+                    found_line = True
+                    break
+            if not found_line:
+                lines[y_center] = [idx]
+
+        sorted_lines = sorted(lines.keys())
+        grouped_texts = []
+        min_x = min_y = float('inf')
+        max_x = max_y = -float('inf')
+
+        for line in sorted_lines:
+            sorted_indices = sorted(lines[line], key=lambda idx: text_boxes[idx][0])
+            line_text = ' '.join(texts[idx] for idx in sorted_indices)
+            grouped_texts.append(line_text)
+
+            for idx in sorted_indices:
+                box = text_boxes[idx]
+                min_x = min(min_x, box[0])
+                min_y = min(min_y, box[1])
+                max_x = max(max_x, box[2])
+                max_y = max(max_y, box[3])
+
+        information_grouped_texts.append(' '.join(grouped_texts))
+        info_sentence_boxes.append([min_x, min_y, max_x, max_y])
+
+    return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes
+
+
+def mapping_text(full_pred, text_pred, print_sentences=False, percentage_thresh=0.6, scale=1.0, iou_threshold=0.5):
+
+    ########### REWORK THIS FUNCTION ###########
+    # Rework the function so that it first handles the elements that are inside tasks and then applies a distance threshold to the remaining elements,
+    # or alternatively compute the distance between all elements rather than only the tasks.
+
+
+    # Example usage
+    boxes = rescale(scale, full_pred['boxes'])
+
+    min_dist = 200
+    labels = full_pred['labels']
+    avoid = [list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation')]
+    for i in range(len(boxes)):
+        box1 = boxes[i]
+        if labels[i] in avoid:
+            continue
+        for j in range(i + 1, len(boxes)):
+            box2 = boxes[j]
+            if labels[j] in avoid:
+                continue
+            dist = min_distance_between_boxes(box1, box2)
+            min_dist = min(min_dist, dist)
+
+    #print("Minimum distance between boxes:", min_dist)
+
+    text_pred[0] = rescale(scale, text_pred[0])
+    task_boxes = [box for i, box in enumerate(boxes) if full_pred['labels'][i] == list(class_dict.values()).index('task')]
+    grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
+    BPMN_id = set(full_pred['BPMN_id'])  # This ensures uniqueness of task names
+    text_mapping = {id: '' for id in BPMN_id}
+
+
+    if print_sentences:
+        for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
+            print("Task-related Text:", sentence)
+            print("Bounding Box:", box)
+        print("Information Texts:", info_texts)
+        print("Information Bounding Boxes:", info_boxes)
+
+    # Map the grouped sentences to the corresponding task
+    for i in range(len(sentence_bounding_boxes)):
+        for j in range(len(boxes)):
+            if proportion_inside(sentence_bounding_boxes[i], boxes[j]) > iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
+                text_mapping[full_pred['BPMN_id'][j]] = grouped_sentences[i]
+
+    # Map the grouped sentences to the corresponding pool
+    for i in range(len(info_boxes)):
+        if is_vertical(info_boxes[i]):
+            for j in range(len(boxes)):
+                if proportion_inside(info_boxes[i], boxes[j]) > 0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
+                    print("Text:", info_texts[i], "associate with ", full_pred['BPMN_id'][j])
+                    bpmn_id = full_pred['BPMN_id'][j]
+                    # Append new text or create new entry if not existing
+                    if bpmn_id in text_mapping:
+                        text_mapping[bpmn_id] += " " + info_texts[i]  # Append text with a space in between
+                    else:
+                        text_mapping[bpmn_id] = info_texts[i]
+                    info_texts[i] = ''  # Clear the text to avoid re-use
+
+    # Map the grouped sentences to the corresponding object
+    for i in range(len(info_boxes)):
+        if is_vertical(info_boxes[i]):
+            continue  # Skip if the text is vertical
+        for j in range(len(boxes)):
+            if info_texts[i] == '':
+                continue  # Skip if there's no text
+            if (proportion_inside(info_boxes[i], boxes[j]) > 0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh*min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
+                    or full_pred['labels'][j] == list(class_dict.values()).index('messageEvent')
+                    or full_pred['labels'][j] == list(class_dict.values()).index('timerEvent')
+                    or full_pred['labels'][j] == list(class_dict.values()).index('dataObject')):
+                bpmn_id = full_pred['BPMN_id'][j]
+                # Append new text or create new entry if not existing
+                if bpmn_id in text_mapping:
+                    text_mapping[bpmn_id] += " " + info_texts[i]  # Append text with a space in between
+                else:
+                    text_mapping[bpmn_id] = info_texts[i]
+                info_texts[i] = ''  # Clear the text to avoid re-use
+
+    # Map the grouped sentences to the corresponding flow
+    for i in range(len(info_boxes)):
+        if info_texts[i] == '' or is_vertical(info_boxes[i]):
+            continue  # Skip if there's no text
+        # Find the closest box within the defined threshold
+        closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4*min_dist)
+        if closest_index is not None and (full_pred['labels'][closest_index] == list(class_dict.values()).index('sequenceFlow') or full_pred['labels'][closest_index] == list(class_dict.values()).index('messageFlow')):
+            bpmn_id = full_pred['BPMN_id'][closest_index]
+            # Append new text or create new entry if not existing
+            if bpmn_id in text_mapping:
+                text_mapping[bpmn_id] += " " + info_texts[i]  # Append text with a space in between
+            else:
+                text_mapping[bpmn_id] = info_texts[i]
+            info_texts[i] = ''  # Clear the text to avoid re-use
+
+    if print_sentences:
+        print("Text Mapping:", text_mapping)
+        print("Information Texts left:", info_texts)
+
+    return text_mapping
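
Usage note (illustrative, not part of the committed files): the functions above are meant to be chained, running Azure OCR on a PIL image, keeping only confident words, then attaching each grouped sentence to the nearest detected BPMN element. A minimal sketch, assuming the Azure credentials are available to OCR.py (via st.secrets or the commented-out VISION_KEY.json path) and that full_pred is the prediction dictionary produced by eval.full_prediction for the same image:

    from PIL import Image
    from OCR import text_prediction, filter_text, mapping_text

    image = Image.open("diagram.png").convert("RGB")    # hypothetical input file

    ocr_raw = text_prediction(image)                    # Azure Read call on the image bytes
    text_pred = filter_text(ocr_raw, threshold=0.5)     # -> [list_of_boxes, list_of_texts]
    text_mapping = mapping_text(full_pred, text_pred,   # full_pred assumed from eval.full_prediction
                                print_sentences=False, percentage_thresh=0.5)
    # text_mapping maps each BPMN_id to the text found inside or closest to that element.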
demo_streamlit.py
ADDED
@@ -0,0 +1,339 @@
+import streamlit as st
+import streamlit.components.v1 as components
+from PIL import Image
+import torch
+from torchvision.transforms import functional as F
+from PIL import Image, ImageEnhance
+from htlm_webpage import display_bpmn_xml
+import gc
+import psutil
+
+from OCR import text_prediction, filter_text, mapping_text, rescale
+from train import prepare_model
+from utils import draw_annotations, create_loader, class_dict, arrow_dict, object_dict
+from toXML import calculate_pool_bounds, add_diagram_elements
+from pathlib import Path
+from toXML import create_bpmn_object, create_flow_element
+import xml.etree.ElementTree as ET
+import numpy as np
+from display import draw_stream
+from eval import full_prediction
+from streamlit_image_comparison import image_comparison
+from xml.dom import minidom
+from streamlit_cropper import st_cropper
+from streamlit_drawable_canvas import st_canvas
+from utils import find_closest_object
+from train import get_faster_rcnn_model, get_arrow_model
+import gdown
+
+def get_memory_usage():
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    return mem_info.rss / (1024 ** 2)  # Return memory usage in MB
+
+def clear_memory():
+    st.session_state.clear()
+    gc.collect()
+
+# Function to read XML content from a file
+def read_xml_file(filepath):
+    """ Read XML content from a file """
+    with open(filepath, 'r', encoding='utf-8') as file:
+        return file.read()
+
+# Function to modify bounding box positions based on the given sizes
+def modif_box_pos(pred, size):
+    for i, (x1, y1, x2, y2) in enumerate(pred['boxes']):
+        center = [(x1 + x2) / 2, (y1 + y2) / 2]
+        label = class_dict[pred['labels'][i]]
+        if label in size:
+            pred['boxes'][i] = [center[0] - size[label][0] / 2, center[1] - size[label][1] / 2, center[0] + size[label][0] / 2, center[1] + size[label][1] / 2]
+    return pred
+
+# Function to create a BPMN XML file from prediction results
+def create_XML(full_pred, text_mapping, scale):
+    namespaces = {
+        'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
+        'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
+        'di': 'http://www.omg.org/spec/DD/20100524/DI',
+        'dc': 'http://www.omg.org/spec/DD/20100524/DC',
+        'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
+    }
+
+    size_elements = {
+        'start': (54, 54),
+        'task': (150, 120),
+        'message': (54, 54),
+        'messageEvent': (54, 54),
+        'end': (54, 54),
+        'exclusiveGateway': (75, 75),
+        'event': (54, 54),
+        'parallelGateway': (75, 75),
+        'sequenceFlow': (225, 15),
+        'pool': (375, 150),
+        'lane': (300, 150),
+        'dataObject': (60, 90),
+        'dataStore': (90, 90),
+        'subProcess': (180, 135),
+        'eventBasedGateway': (75, 75),
+        'timerEvent': (60, 60),
+    }
+
+
+    definitions = ET.Element('bpmn:definitions', {
+        'xmlns:xsi': namespaces['xsi'],
+        'xmlns:bpmn': namespaces['bpmn'],
+        'xmlns:bpmndi': namespaces['bpmndi'],
+        'xmlns:di': namespaces['di'],
+        'xmlns:dc': namespaces['dc'],
+        'targetNamespace': "http://example.bpmn.com",
+        'id': "simpleExample"
+    })
+
+    # Create BPMN collaboration element
+    collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')
+
+    # Create BPMN process elements
+    process = []
+    for idx in range(len(full_pred['pool_dict'].items())):
+        process_id = f'process_{idx+1}'
+        process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]]))
+
+    bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
+    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
+
+    full_pred['boxes'] = rescale(scale, full_pred['boxes'])
+
+    # Add diagram elements for each pool
+    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
+        pool_id = f'participant_{idx+1}'
+        pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])
+
+        # Calculate the bounding box for the pool
+        if len(keep_elements) == 0:
+            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
+            pool_width = max_x - min_x
+            pool_height = max_y - min_y
+        else:
+            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
+            pool_width = max_x - min_x + 100   # Adding padding
+            pool_height = max_y - min_y + 100  # Adding padding
+
+        add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)
+
+    # Create BPMN elements for each pool
+    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
+        create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
+
+    # Create message flow elements
+    message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
+    for idx in message_flows:
+        create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)
+
+    # Create sequence flow elements
+    for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
+        for i in keep_elements:
+            if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
+                create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
+
+    # Generate pretty XML string
+    tree = ET.ElementTree(definitions)
+    rough_string = ET.tostring(definitions, 'utf-8')
+    reparsed = minidom.parseString(rough_string)
+    pretty_xml_as_string = reparsed.toprettyxml(indent="  ")
+
+    full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])
+
+    return pretty_xml_as_string
+
+
+# Function to load the models only once and use session state to keep track of it
+def load_models():
+    with st.spinner('Loading model...'):
+        model_object = get_faster_rcnn_model(len(object_dict))
+        model_arrow = get_arrow_model(len(arrow_dict), 2)
+
+        url_arrow = 'https://drive.google.com/uc?id=1xwfvo7BgDWz-1jAiJC1DCF0Wp8YlFNWt'
+        url_object = 'https://drive.google.com/uc?id=1GiM8xOXG6M6r8J9HTOeMJz9NKu7iumZi'
+
+        # Define paths to save models
+        output_arrow = 'model_arrow.pth'
+        output_object = 'model_object.pth'
+
+        # Download models using gdown
+        if not Path(output_arrow).exists():
+            # Download models using gdown
+            gdown.download(url_arrow, output_arrow, quiet=False)
+        else:
+            print('Model arrow downloaded from local')
+        if not Path(output_object).exists():
+            gdown.download(url_object, output_object, quiet=False)
+        else:
+            print('Model object downloaded from local')
+
+        # Load models
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
+        model_object.load_state_dict(torch.load(output_object, map_location=device))
+        st.session_state.model_loaded = True
+        st.session_state.model_arrow = model_arrow
+        st.session_state.model_object = model_object
+
+# Function to prepare the image for processing
+def prepare_image(image, pad=True, new_size=(1333, 1333)):
+    original_size = image.size
+    # Calculate scale to fit the new size while maintaining aspect ratio
+    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
+    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
+    # Resize image to new scaled size
+    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
+
+    if pad:
+        enhancer = ImageEnhance.Brightness(image)
+        image = enhancer.enhance(1.5)  # Adjust the brightness if necessary
+        # Pad the resized image to make it exactly the desired size
+        padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
+        image = F.pad(image, padding, fill=200, padding_mode='edge')
+
+    return new_scaled_size, image
+
+# Function to display various options for image annotation
+def display_options(image, score_threshold):
+    col1, col2, col3, col4, col5 = st.columns(5)
+    with col1:
+        write_class = st.toggle("Write Class", value=True)
+        draw_keypoints = st.toggle("Draw Keypoints", value=True)
+        draw_boxes = st.toggle("Draw Boxes", value=True)
+    with col2:
+        draw_text = st.toggle("Draw Text", value=False)
+        write_text = st.toggle("Write Text", value=False)
+        draw_links = st.toggle("Draw Links", value=False)
+    with col3:
+        write_score = st.toggle("Write Score", value=True)
+        write_idx = st.toggle("Write Index", value=False)
+    with col4:
+        # Define options for the dropdown menu
+        dropdown_options = [list(class_dict.values())[i] for i in range(len(class_dict))]
+        dropdown_options[0] = 'all'
+        selected_option = st.selectbox("Show class", dropdown_options)
+
+    # Draw the annotated image with selected options
+    annotated_image = draw_stream(
+        np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
+        draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
+        write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_print=selected_option,
+        score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
+    )
+
+    # Display the original and annotated images side by side
+    image_comparison(
+        img1=annotated_image,
+        img2=image,
+        label1="Annotated Image",
+        label2="Original Image",
+        starting_position=99,
+        width=1000,
+    )
+
+# Function to perform inference on the uploaded image using the loaded models
+def perform_inference(model_object, model_arrow, image, score_threshold):
+    _, uploaded_image = prepare_image(image, pad=False)
+
+    img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
+
+    # Display original image
+    if 'image_placeholder' not in st.session_state:
+        image_placeholder = st.empty()  # Create an empty placeholder
+        image_placeholder.image(uploaded_image, caption='Original Image', width=1000)
+
+    # Prediction
+    _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5)
+
+    # Perform OCR on the uploaded image
+    ocr_results = text_prediction(uploaded_image)
+
+    # Filter and map OCR results to prediction results
+    st.session_state.text_pred = filter_text(ocr_results, threshold=0.5)
+    st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)
+
+    # Remove the original image display
+    image_placeholder.empty()
+
+    # Force garbage collection
+    gc.collect()
+
+@st.cache_data
+def get_image(uploaded_file):
+    return Image.open(uploaded_file).convert('RGB')
+
+def main():
+    st.set_page_config(layout="wide")
+    st.title("BPMN model recognition demo")
+
+    # Display current memory usage
+    memory_usage = get_memory_usage()
+    print(f"Current memory usage: {memory_usage:.2f} MB")
+
+    # Initialize the session state for storing pool bounding boxes
+    if 'pool_bboxes' not in st.session_state:
+        st.session_state.pool_bboxes = []
+
+    # Load the models using the defined function
+    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
+        clear_memory()
+        load_models()
+
+    model_arrow = st.session_state.model_arrow
+    model_object = st.session_state.model_object
+
+    # Create the layout for the app
+    col1, col2 = st.columns(2)
+    with col1:
+        # Create a file uploader for the user to upload an image
+        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+    # Display the uploaded image if the user has uploaded an image
+    if uploaded_file is not None:
+        original_image = get_image(uploaded_file)
+        col1, col2 = st.columns(2)
+
+        # Create a cropper to allow the user to crop the image and display the cropped image
+        with col1:
+            cropped_image = st_cropper(original_image, realtime_update=True, box_color='#0000FF', should_resize_image=True, default_coords=(30, original_image.size[0]-30, 30, original_image.size[1]-30))
+        with col2:
+            st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=500)
+
+        # Display the options for the user to set the score threshold and scale
+        if cropped_image is not None:
+            col1, col2, col3 = st.columns(3)
+            with col1:
+                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
+            with col2:
+                st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
+
+            # Launch the prediction when the user clicks the button
+            if st.button("Launch Prediction"):
+                st.session_state.crop_image = cropped_image
+                with st.spinner('Processing...'):
+                    perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
+                    st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
+
+                    print('Detection completed!')
+
+
+    # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
+    if 'prediction' in st.session_state and uploaded_file is not None:
+        st.success('Detection completed!')
+        display_options(st.session_state.crop_image, score_threshold)
+
+        #if st.session_state.prediction_up==True:
+        st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
+
+        display_bpmn_xml(st.session_state.bpmn_xml)
+
+    # Force garbage collection after display
+    gc.collect()
+
+if __name__ == "__main__":
+    print('Starting the app...')
+    main()
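
Deployment note (illustrative, not part of the committed files): the app is started with `streamlit run demo_streamlit.py`, and OCR.py reads VISION_KEY and VISION_ENDPOINT from st.secrets at import time, so a local run needs the git-ignored .streamlit/secrets.toml (or the Space's secrets) to define both keys. A small, assumed pre-flight check could look like this:

    import streamlit as st

    # Hypothetical check, assuming the two Azure keys used by OCR.py; the app
    # will otherwise fail when OCR.py is imported.
    for key in ("VISION_KEY", "VISION_ENDPOINT"):
        if key not in st.secrets:
            raise RuntimeError(f"{key} is missing from .streamlit/secrets.toml or the Space secrets")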
display.py
ADDED
@@ -0,0 +1,181 @@
+from utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
+import cv2
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+from OCR import group_texts
+
+
+
+
+def draw_stream(image,
+                prediction=None,
+                text_predictions=None,
+                class_dict=class_dict,
+                draw_keypoints=False,
+                draw_boxes=False,
+                draw_text=False,
+                draw_links=False,
+                draw_twins=False,
+                draw_grouped_text=False,
+                write_class=False,
+                write_score=False,
+                write_text=False,
+                score_threshold=0.4,
+                write_idx=False,
+                keypoints_correction=False,
+                new_size=(1333, 1333),
+                only_print=None,
+                axis=False,
+                return_image=False,
+                resize=False):
+    """
+    Draws annotations on images including bounding boxes, keypoints, links, and text.
+
+    Parameters:
+    - image (np.array): The image on which annotations will be drawn.
+    - target (dict): Ground truth data containing boxes, labels, etc.
+    - prediction (dict): Prediction data from a model.
+    - full_prediction (dict): Additional detailed prediction data, potentially including relationships.
+    - text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
+    - class_dict (dict): Mapping from class IDs to class names.
+    - draw_keypoints (bool): Flag to draw keypoints.
+    - draw_boxes (bool): Flag to draw bounding boxes.
+    - draw_text (bool): Flag to draw text annotations.
+    - draw_links (bool): Flag to draw links between annotations.
+    - draw_twins (bool): Flag to draw twins keypoints.
+    - write_class (bool): Flag to write class names near the annotations.
+    - write_score (bool): Flag to write scores near the annotations.
+    - write_text (bool): Flag to write OCR recognized text.
+    - score_threshold (float): Threshold for scores above which annotations will be drawn.
+    - only_print (str): Specific class name to filter annotations by.
+    - resize (bool): Whether to resize annotations to fit the image size.
+    """
+
+    # Convert image to RGB (if not already in that format)
+    if prediction is None:
+        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
+
+
+    image_copy = image.copy()
+    scale = max(image.shape[0], image.shape[1]) / 1000
+
+    original_size = (image.shape[0], image.shape[1])
+    # Calculate scale to fit the new size while maintaining aspect ratio
+    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
+    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))
+
+    for i in range(len(prediction['boxes'])):
+        box = prediction['boxes'][i]
+        x1, y1, x2, y2 = box
+        if resize:
+            x1, y1, x2, y2 = resize_boxes(np.array([box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+        score = prediction['scores'][i]
+        if score < score_threshold:
+            continue
+        if draw_boxes:
+            if only_print is not None and only_print != 'all':
+                if prediction['labels'][i] != list(class_dict.values()).index(only_print):
+                    continue
+            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))
+        if write_score:
+            cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100, 100, 255), 2)
+        if write_idx:
+            cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0, 0, 0), 2)
+
+        if write_class and 'labels' in prediction:
+            class_id = prediction['labels'][i]
+            cv2.putText(image_copy, class_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
+
+
+    # Draw keypoints if available
+    if draw_keypoints and 'keypoints' in prediction:
+        for i in range(len(prediction['keypoints'])):
+            kp = prediction['keypoints'][i]
+            for j in range(kp.shape[0]):
+                if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
+                    continue
+
+                score = prediction['scores'][i]
+                if score < score_threshold:
+                    continue
+                x, y, v = np.array(kp[j])
+                x, y, v = resize_keypoints(np.array([kp[j]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+                if j == 0:
+                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
+                else:
+                    cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
+
+    # Draw text predictions if available
+    if (draw_text or write_text) and text_predictions is not None:
+        for i in range(len(text_predictions[0])):
+            x1, y1, x2, y2 = text_predictions[0][i]
+            text = text_predictions[1][i]
+            if resize:
+                x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            if draw_text:
+                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
+            if write_text:
+                cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0, 0, 0), 2)
+
+
+    '''Draws links between objects based on the full prediction data.'''
+    # Check if keypoints detected are the same
+    if draw_twins and prediction is not None:
+        # Pre-calculate indices for performance
+        circle_color = (0, 255, 0)       # Green color for the circle
+        circle_radius = int(10 * scale)  # Circle radius scaled by image scale
+
+        for idx, (key1, key2) in enumerate(prediction['keypoints']):
+            if prediction['labels'][idx] not in [list(class_dict.values()).index('sequenceFlow'),
+                                                 list(class_dict.values()).index('messageFlow'),
+                                                 list(class_dict.values()).index('dataAssociation')]:
+                continue
+            # Calculate the Euclidean distance between the two keypoints
+            distance = np.linalg.norm(key1[:2] - key2[:2])
+            if distance < 10:
+                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
+                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
+                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)
+
+    # Draw links between objects
+    if draw_links == True and prediction is not None:
+        for i, (start_idx, end_idx) in enumerate(prediction['links']):
+            if start_idx is None or end_idx is None:
+                continue
+            start_box = prediction['boxes'][start_idx]
+            start_box = resize_boxes(np.array([start_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            end_box = prediction['boxes'][end_idx]
+            end_box = resize_boxes(np.array([end_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            current_box = prediction['boxes'][i]
+            current_box = resize_boxes(np.array([current_box]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            # Calculate the center of each bounding box
+            start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
+            end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
+            current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
+            # Draw a line between the centers of the connected objects
+            cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
+            cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
+
+
+    if draw_grouped_text and prediction is not None:
+        task_boxes = [box for i, box in enumerate(prediction['boxes']) if prediction['labels'][i] == list(class_dict.values()).index('task')]
+        grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
+        for i in range(len(info_boxes)):
+            x1, y1, x2, y2 = info_boxes[i]
+            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
+        for i in range(len(sentence_bounding_boxes)):
+            x1, y1, x2, y2 = sentence_bounding_boxes[i]
+            x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), (new_scaled_size[1], new_scaled_size[0]), (image_copy.shape[1], image_copy.shape[0]))[0]
+            cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
+
+    if return_image:
+        return image_copy
+    else:
+        # Display the image
+        plt.figure(figsize=(12, 12))
+        plt.imshow(image_copy)
+        if axis == False:
+            plt.axis('off')
+        plt.show()
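
Usage note (illustrative, not part of the committed files): draw_stream is the annotation overlay used by display_options() in demo_streamlit.py. A minimal sketch, assuming `image` is a PIL image of the diagram and that `prediction` and `text_pred` are the outputs of eval.full_prediction and OCR.filter_text for that same image:

    import numpy as np
    from display import draw_stream

    annotated = draw_stream(
        np.array(image),                 # image assumed to be a PIL image of the diagram
        prediction=prediction,           # assumed from eval.full_prediction
        text_predictions=text_pred,      # assumed from OCR.filter_text
        draw_boxes=True, draw_keypoints=True, write_class=True,
        score_threshold=0.5, resize=True, return_image=True,
    )
    # `annotated` is a NumPy array with boxes and keypoints drawn on it; with
    # return_image=False the function shows the result with matplotlib instead.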
eval.py
ADDED
@@ -0,0 +1,649 @@
+import numpy as np
+import torch
+from utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
+from tqdm import tqdm
+from toXML import create_BPMN_id
+
+
+
+
+def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
+    idxs = np.argsort(scores)  # Sort the boxes according to their scores in ascending order
+    selected_boxes = []
+
+    while len(idxs) > 0:
+        last = len(idxs) - 1
+        i = idxs[last]
+
+        # Skip if the label is a lane
+        if labels is not None and class_dict[labels[i]] == 'lane':
+            selected_boxes.append(i)
+            idxs = np.delete(idxs, last)
+            continue
+
+        selected_boxes.append(i)
+
+        # Find the intersection of the box with the rest
+        suppress = [last]
+        for pos in range(0, last):
+            j = idxs[pos]
+            if iou(boxes[i], boxes[j]) > iou_threshold:
+                suppress.append(pos)
+
+        idxs = np.delete(idxs, suppress)
+
+    # Return only the boxes that were selected
+    return selected_boxes
+
+
+def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
+    for idx, (key1, key2) in enumerate(keypoints):
+        if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
+                               list(model_dict.values()).index('messageFlow'),
+                               list(model_dict.values()).index('dataAssociation')]:
+            continue
+        # Calculate the Euclidean distance between the two keypoints
+        distance = np.linalg.norm(key1[:2] - key2[:2])
+        if distance < distance_treshold:
+            print('Key modified for index:', idx)
+            x_new, y_new, x, y = find_other_keypoint(idx, keypoints, boxes)
+            keypoints[idx][0][:2] = [x_new, y_new]
+            keypoints[idx][1][:2] = [x, y]
+
+    return keypoints
+
+
+def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
+    model.eval()
+    with torch.no_grad():
+        image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
+        predictions = model(image_tensor)
+
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    labels = predictions[0]['labels'].cpu().numpy()
+    scores = predictions[0]['scores'].cpu().numpy()
+
+    idx = np.where(scores > score_threshold)[0]
+    boxes = boxes[idx]
+    scores = scores[idx]
+    labels = labels[idx]
+
+    selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
+
+    # Find the orientation of the tasks by checking the size of all the boxes, and drop the ones that are not in the same orientation
+    vertical = 0
+    for i in range(len(labels)):
+        if labels[i] != list(object_dict.values()).index('task'):
+            continue
+        if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
+            vertical += 1
+    horizontal = len(labels) - vertical
+    for i in range(len(labels)):
+        if labels[i] != list(object_dict.values()).index('task'):
+            continue
+
+        if vertical < horizontal:
+            if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
+                # Find the element in the list and remove it
+                if i in selected_boxes:
+                    selected_boxes.remove(i)
+        elif vertical > horizontal:
+            if boxes[i][2]-boxes[i][0] > boxes[i][3]-boxes[i][1]:
+                # Find the element in the list and remove it
+                if i in selected_boxes:
+                    selected_boxes.remove(i)
+        else:
+            pass
+
+    boxes = boxes[selected_boxes]
+    scores = scores[selected_boxes]
+    labels = labels[selected_boxes]
+
+    prediction = {
+        'boxes': boxes,
+        'scores': scores,
+        'labels': labels,
+    }
+
+    image = image.permute(1, 2, 0).cpu().numpy()
+    image = (image * 255).astype(np.uint8)
+
+    return image, prediction
+
+
+def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
+    model.eval()
+    with torch.no_grad():
+        image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
+        predictions = model(image_tensor)
+
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    labels = predictions[0]['labels'].cpu().numpy() + (len(object_dict) - 1)
+    scores = predictions[0]['scores'].cpu().numpy()
+    keypoints = predictions[0]['keypoints'].cpu().numpy()
+
+    idx = np.where(scores > score_threshold)[0]
+    boxes = boxes[idx]
+    scores = scores[idx]
+    labels = labels[idx]
+    keypoints = keypoints[idx]
+
+    selected_boxes = non_maximum_suppression(boxes, scores, iou_threshold=iou_threshold)
+    boxes = boxes[selected_boxes]
+    scores = scores[selected_boxes]
+    labels = labels[selected_boxes]
+    keypoints = keypoints[selected_boxes]
+
+    keypoints = keypoint_correction(keypoints, boxes, labels, class_dict, distance_treshold=distance_treshold)
+
+    prediction = {
+        'boxes': boxes,
+        'scores': scores,
+        'labels': labels,
+        'keypoints': keypoints,
+    }
+
+    image = image.permute(1, 2, 0).cpu().numpy()
+    image = (image * 255).astype(np.uint8)
+
+    return image, prediction
+
+def mix_predictions(objects_pred, arrow_pred):
+    # Initialize the list of lists for keypoints
+    object_keypoints = []
+
+    # Number of boxes
+    num_boxes = len(objects_pred['boxes'])
+
+    # Iterate over the number of boxes
+    for _ in range(num_boxes):
+        # Each box has 2 keypoints, both initialized to [0, 0, 0]
+        keypoints = [[0, 0, 0], [0, 0, 0]]
+        object_keypoints.append(keypoints)
+
+    # Concatenate the two predictions
+    boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
+    labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
+    scores = np.concatenate((objects_pred['scores'], arrow_pred['scores']))
+    keypoints = np.concatenate((object_keypoints, arrow_pred['keypoints']))
+
+    return boxes, labels, scores, keypoints
171 |
+
|
172 |
+
def regroup_elements_by_pool(boxes, labels, class_dict):
|
173 |
+
"""
|
174 |
+
Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
|
175 |
+
|
176 |
+
Parameters:
|
177 |
+
- boxes (list): List of bounding boxes.
|
178 |
+
- labels (list): List of labels corresponding to each bounding box.
|
179 |
+
- class_dict (dict): Dictionary mapping class indices to class names.
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
- dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
|
183 |
+
"""
|
184 |
+
# Initialize a dictionary to hold the elements in each pool
|
185 |
+
pool_dict = {}
|
186 |
+
|
187 |
+
# Identify the bounding boxes of the pools
|
188 |
+
pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
|
189 |
+
pool_boxes = [boxes[i] for i in pool_indices]
|
190 |
+
|
191 |
+
if not pool_indices:
|
192 |
+
# If no pools or lanes are detected, create a single pool with all elements
|
193 |
+
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
194 |
+
pool_dict[len(labels)-1] = list(range(len(boxes)))
|
195 |
+
else:
|
196 |
+
# Initialize each pool index with an empty list
|
197 |
+
for pool_index in pool_indices:
|
198 |
+
pool_dict[pool_index] = []
|
199 |
+
|
200 |
+
# Initialize a list for elements not in any pool
|
201 |
+
elements_not_in_pool = []
|
202 |
+
|
203 |
+
# Iterate over all elements
|
204 |
+
for i, box in enumerate(boxes):
|
205 |
+
if i in pool_indices or class_dict[labels[i]] == 'messageFlow':
|
206 |
+
continue # Skip pool boxes themselves and messageFlow elements
|
207 |
+
assigned_to_pool = False
|
208 |
+
for j, pool_box in enumerate(pool_boxes):
|
209 |
+
# Check if the element is within the pool's bounding box
|
210 |
+
if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
|
211 |
+
box[2] <= pool_box[2] and box[3] <= pool_box[3]):
|
212 |
+
pool_index = pool_indices[j]
|
213 |
+
pool_dict[pool_index].append(i)
|
214 |
+
assigned_to_pool = True
|
215 |
+
break
|
216 |
+
if not assigned_to_pool:
|
217 |
+
if class_dict[labels[i]] != 'messageFlow' and class_dict[labels[i]] != 'lane':
|
218 |
+
elements_not_in_pool.append(i)
|
219 |
+
|
220 |
+
if elements_not_in_pool:
|
221 |
+
new_pool_index = max(pool_dict.keys()) + 1
|
222 |
+
labels = np.append(labels, list(class_dict.values()).index('pool'))
|
223 |
+
pool_dict[new_pool_index] = elements_not_in_pool
|
224 |
+
|
225 |
+
# Separate empty pools
|
226 |
+
non_empty_pools = {k: v for k, v in pool_dict.items() if v}
|
227 |
+
empty_pools = {k: v for k, v in pool_dict.items() if not v}
|
228 |
+
|
229 |
+
# Merge non-empty pools followed by empty pools
|
230 |
+
pool_dict = {**non_empty_pools, **empty_pools}
|
231 |
+
|
232 |
+
return pool_dict, labels
|
233 |
+
|
234 |
+
|
235 |
+
def create_links(keypoints, boxes, labels, class_dict):
|
236 |
+
best_points = []
|
237 |
+
links = []
|
238 |
+
for i in range(len(labels)):
|
239 |
+
if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
|
240 |
+
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
241 |
+
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
242 |
+
if closest1 is not None and closest2 is not None:
|
243 |
+
best_points.append([point_start, point_end])
|
244 |
+
links.append([closest1, closest2])
|
245 |
+
else:
|
246 |
+
best_points.append([None,None])
|
247 |
+
links.append([None,None])
|
248 |
+
|
249 |
+
for i in range(len(labels)):
|
250 |
+
if labels[i]==list(class_dict.values()).index('dataAssociation'):
|
251 |
+
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
252 |
+
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
253 |
+
if closest1 is not None and closest2 is not None:
|
254 |
+
best_points[i] = ([point_start, point_end])
|
255 |
+
links[i] = ([closest1, closest2])
|
256 |
+
|
257 |
+
return links, best_points
|
258 |
+
|
259 |
+
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
260 |
+
|
261 |
+
for pool_index, elements in pool_dict.items():
|
262 |
+
print(f"Pool {pool_index} contains elements: {elements}")
|
263 |
+
#check if each link is in the same pool
|
264 |
+
for i in range(len(flow_links)):
|
265 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow'):
|
266 |
+
id1, id2 = flow_links[i]
|
267 |
+
if (id1 and id2) is not None:
|
268 |
+
if id1 in elements and id2 in elements:
|
269 |
+
continue
|
270 |
+
elif id1 not in elements and id2 not in elements:
|
271 |
+
continue
|
272 |
+
else:
|
273 |
+
print('change the link from sequenceFlow to messageFlow')
|
274 |
+
labels[i]=list(class_dict.values()).index('messageFlow')
|
275 |
+
|
276 |
+
return labels, flow_links
|
277 |
+
|
278 |
+
|
279 |
+
def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict):
|
280 |
+
|
281 |
+
#delete pool that are have only messageFlow on it
|
282 |
+
delete_pool = []
|
283 |
+
for pool_index, elements in pool_dict.items():
|
284 |
+
if all([labels[i] == list(class_dict.values()).index('messageFlow') for i in elements]):
|
285 |
+
if len(elements) > 0:
|
286 |
+
delete_pool.append(pool_dict[pool_index])
|
287 |
+
print(f"Pool {pool_index} contains only messageFlow elements, deleting it")
|
288 |
+
|
289 |
+
#sort index
|
290 |
+
delete_pool = sorted(delete_pool, reverse=True)
|
291 |
+
for pool in delete_pool:
|
292 |
+
index = list(pool_dict.keys())[list(pool_dict.values()).index(pool)]
|
293 |
+
del pool_dict[index]
|
294 |
+
|
295 |
+
|
296 |
+
delete_elements = []
|
297 |
+
# Check if there is an arrow that has the same links
|
298 |
+
for i in range(len(labels)):
|
299 |
+
for j in range(i+1, len(labels)):
|
300 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
|
301 |
+
if links[i] == links[j]:
|
302 |
+
print(f'element {i} and {j} have the same links')
|
303 |
+
if scores[i] > scores[j]:
|
304 |
+
print('delete element', j)
|
305 |
+
delete_elements.append(j)
|
306 |
+
else:
|
307 |
+
print('delete element', i)
|
308 |
+
delete_elements.append(i)
|
309 |
+
|
310 |
+
boxes = np.delete(boxes, delete_elements, axis=0)
|
311 |
+
labels = np.delete(labels, delete_elements)
|
312 |
+
scores = np.delete(scores, delete_elements)
|
313 |
+
keypoints = np.delete(keypoints, delete_elements, axis=0)
|
314 |
+
links = np.delete(links, delete_elements, axis=0)
|
315 |
+
best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
|
316 |
+
|
317 |
+
#also delete the element in the pool_dict
|
318 |
+
for pool_index, elements in pool_dict.items():
|
319 |
+
pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
|
320 |
+
|
321 |
+
return boxes, labels, scores, keypoints, links, best_points, pool_dict
|
322 |
+
|
323 |
+
def give_link_to_element(links, labels):
|
324 |
+
#give a link to event to allow the creation of the BPMN id with start, indermediate and end event
|
325 |
+
for i in range(len(links)):
|
326 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow'):
|
327 |
+
id1, id2 = links[i]
|
328 |
+
if (id1 and id2) is not None:
|
329 |
+
links[id1][1] = i
|
330 |
+
links[id2][0] = i
|
331 |
+
return links
|
332 |
+
|
333 |
+
def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
|
334 |
+
model_object.eval() # Set the model to evaluation mode
|
335 |
+
model_arrow.eval() # Set the model to evaluation mode
|
336 |
+
|
337 |
+
# Load an image
|
338 |
+
with torch.no_grad(): # Disable gradient calculation for inference
|
339 |
+
_, objects_pred = object_prediction(model_object, image, score_threshold=score_threshold, iou_threshold=iou_threshold)
|
340 |
+
_, arrow_pred = arrow_prediction(model_arrow, image, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
341 |
+
|
342 |
+
#print('Object prediction:', objects_pred)
|
343 |
+
|
344 |
+
|
345 |
+
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
346 |
+
|
347 |
+
# Regroup elements by pool
|
348 |
+
pool_dict, labels = regroup_elements_by_pool(boxes,labels, class_dict)
|
349 |
+
# Create links between elements
|
350 |
+
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
351 |
+
#Correct the labels of some sequenceflow that cross multiple pool
|
352 |
+
labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
|
353 |
+
#give a link to event to allow the creation of the BPMN id with start, indermediate and end event
|
354 |
+
flow_links = give_link_to_element(flow_links, labels)
|
355 |
+
|
356 |
+
boxes,labels,scores,keypoints,flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,flow_links,best_points, pool_dict)
|
357 |
+
|
358 |
+
image = image.permute(1, 2, 0).cpu().numpy()
|
359 |
+
image = (image * 255).astype(np.uint8)
|
360 |
+
idx = []
|
361 |
+
for i in range(len(labels)):
|
362 |
+
idx.append(i)
|
363 |
+
bpmn_id = [class_dict[labels[i]] for i in range(len(labels))]
|
364 |
+
|
365 |
+
data = {
|
366 |
+
'image': image,
|
367 |
+
'idx': idx,
|
368 |
+
'boxes': boxes,
|
369 |
+
'labels': labels,
|
370 |
+
'scores': scores,
|
371 |
+
'keypoints': keypoints,
|
372 |
+
'links': flow_links,
|
373 |
+
'best_points': best_points,
|
374 |
+
'pool_dict': pool_dict,
|
375 |
+
'BPMN_id': bpmn_id,
|
376 |
+
}
|
377 |
+
|
378 |
+
# give a unique BPMN id to each element
|
379 |
+
data = create_BPMN_id(data)
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
return image, data
|
384 |
+
|
385 |
+
def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
|
386 |
+
# Initialize dictionaries to hold per-class counts
|
387 |
+
class_tp = {cls: 0 for cls in model_dict.values()}
|
388 |
+
class_fp = {cls: 0 for cls in model_dict.values()}
|
389 |
+
class_fn = {cls: 0 for cls in model_dict.values()}
|
390 |
+
|
391 |
+
# Track which true boxes have been matched
|
392 |
+
matched = [False] * len(true_boxes)
|
393 |
+
|
394 |
+
# Check each prediction against true boxes
|
395 |
+
for pred_box, pred_label in zip(pred_boxes, pred_labels):
|
396 |
+
match_found = False
|
397 |
+
for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
398 |
+
if not matched[idx] and pred_label == true_label:
|
399 |
+
if iou(np.array(pred_box), np.array(true_box)) >= iou_threshold:
|
400 |
+
class_tp[model_dict[pred_label]] += 1
|
401 |
+
matched[idx] = True
|
402 |
+
match_found = True
|
403 |
+
break
|
404 |
+
if not match_found:
|
405 |
+
class_fp[model_dict[pred_label]] += 1
|
406 |
+
|
407 |
+
# Count false negatives
|
408 |
+
for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
409 |
+
if not matched[idx]:
|
410 |
+
class_fn[model_dict[true_label]] += 1
|
411 |
+
|
412 |
+
# Calculate precision, recall, and F1-score per class
|
413 |
+
class_precision = {}
|
414 |
+
class_recall = {}
|
415 |
+
class_f1_score = {}
|
416 |
+
|
417 |
+
for cls in model_dict.values():
|
418 |
+
precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0
|
419 |
+
recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0
|
420 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
|
421 |
+
|
422 |
+
class_precision[cls] = precision
|
423 |
+
class_recall[cls] = recall
|
424 |
+
class_f1_score[cls] = f1_score
|
425 |
+
|
426 |
+
return class_precision, class_recall, class_f1_score
|
427 |
+
|
428 |
+
|
429 |
+
def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
|
430 |
+
result = 0
|
431 |
+
reverted = False
|
432 |
+
#find the position of keypoints in the list
|
433 |
+
idx = np.where(pred_boxes == pred_box)[0][0]
|
434 |
+
idx2 = np.where(true_boxes == true_box)[0][0]
|
435 |
+
|
436 |
+
keypoint1_pred = pred_keypoints[idx][0]
|
437 |
+
keypoint1_true = true_keypoints[idx2][0]
|
438 |
+
keypoint2_pred = pred_keypoints[idx][1]
|
439 |
+
keypoint2_true = true_keypoints[idx2][1]
|
440 |
+
|
441 |
+
distance1 = np.linalg.norm(keypoint1_pred[:2] - keypoint1_true[:2])
|
442 |
+
distance2 = np.linalg.norm(keypoint2_pred[:2] - keypoint2_true[:2])
|
443 |
+
distance3 = np.linalg.norm(keypoint1_pred[:2] - keypoint2_true[:2])
|
444 |
+
distance4 = np.linalg.norm(keypoint2_pred[:2] - keypoint1_true[:2])
|
445 |
+
|
446 |
+
if distance1 < distance_threshold:
|
447 |
+
result += 1
|
448 |
+
if distance2 < distance_threshold:
|
449 |
+
result += 1
|
450 |
+
if distance3 < distance_threshold or distance4 < distance_threshold:
|
451 |
+
reverted = True
|
452 |
+
|
453 |
+
return result, reverted
|
454 |
+
|
455 |
+
def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
|
456 |
+
tp, fp, fn = 0, 0, 0
|
457 |
+
key_t, key_f = 0, 0
|
458 |
+
labels_t, labels_f = 0, 0
|
459 |
+
reverted_tot = 0
|
460 |
+
|
461 |
+
matched_true_boxes = set()
|
462 |
+
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
463 |
+
match_found = False
|
464 |
+
for true_idx, true_box in enumerate(true_boxes):
|
465 |
+
if true_idx in matched_true_boxes:
|
466 |
+
continue
|
467 |
+
iou_val = iou(pred_box, true_box)
|
468 |
+
if iou_val >= iou_threshold:
|
469 |
+
if true_keypoints is not None and pred_keypoints is not None:
|
470 |
+
key_result, reverted = keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold)
|
471 |
+
key_t += key_result
|
472 |
+
key_f += 2 - key_result
|
473 |
+
if reverted:
|
474 |
+
reverted_tot += 1
|
475 |
+
|
476 |
+
match_found = True
|
477 |
+
matched_true_boxes.add(true_idx)
|
478 |
+
if pred_label == true_labels[true_idx]:
|
479 |
+
labels_t += 1
|
480 |
+
else:
|
481 |
+
labels_f += 1
|
482 |
+
tp += 1
|
483 |
+
break
|
484 |
+
if not match_found:
|
485 |
+
fp += 1
|
486 |
+
|
487 |
+
fn = len(true_boxes) - tp
|
488 |
+
|
489 |
+
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted_tot
|
490 |
+
|
491 |
+
|
492 |
+
def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
|
493 |
+
model.eval()
|
494 |
+
tp, fp, fn = 0, 0, 0
|
495 |
+
labels_t, labels_f = 0, 0
|
496 |
+
key_t, key_f = 0, 0
|
497 |
+
reverted = 0
|
498 |
+
|
499 |
+
with torch.no_grad():
|
500 |
+
for images, targets_im in tqdm(loader, desc="Testing... "): # Wrap the loader with tqdm
|
501 |
+
devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
502 |
+
images = [image.to(devices) for image in images]
|
503 |
+
targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im]
|
504 |
+
|
505 |
+
predictions = model(images)
|
506 |
+
|
507 |
+
for target, prediction in zip(targets, predictions):
|
508 |
+
true_boxes = target['boxes'].cpu().numpy()
|
509 |
+
true_labels = target['labels'].cpu().numpy()
|
510 |
+
if 'keypoints' in target:
|
511 |
+
true_keypoints = target['keypoints'].cpu().numpy()
|
512 |
+
|
513 |
+
pred_boxes = prediction['boxes'].cpu().numpy()
|
514 |
+
scores = prediction['scores'].cpu().numpy()
|
515 |
+
pred_labels = prediction['labels'].cpu().numpy()
|
516 |
+
if 'keypoints' in prediction:
|
517 |
+
pred_keypoints = prediction['keypoints'].cpu().numpy()
|
518 |
+
|
519 |
+
selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
|
520 |
+
pred_boxes = pred_boxes[selected_boxes]
|
521 |
+
scores = scores[selected_boxes]
|
522 |
+
pred_labels = pred_labels[selected_boxes]
|
523 |
+
if 'keypoints' in prediction:
|
524 |
+
pred_keypoints = pred_keypoints[selected_boxes]
|
525 |
+
|
526 |
+
filtered_boxes = []
|
527 |
+
filtered_labels = []
|
528 |
+
filtered_keypoints = []
|
529 |
+
if 'keypoints' not in prediction:
|
530 |
+
#create a list of zeros of length equal to the number of boxes
|
531 |
+
pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
|
532 |
+
|
533 |
+
for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
|
534 |
+
if score >= score_threshold:
|
535 |
+
filtered_boxes.append(box)
|
536 |
+
filtered_labels.append(label)
|
537 |
+
if 'keypoints' in prediction:
|
538 |
+
filtered_keypoints.append(keypoints)
|
539 |
+
|
540 |
+
if key_correction and ('keypoints' in prediction):
|
541 |
+
filtered_keypoints = keypoint_correction(filtered_keypoints, filtered_boxes, filtered_labels)
|
542 |
+
|
543 |
+
if 'keypoints' not in target:
|
544 |
+
filtered_keypoints = None
|
545 |
+
true_keypoints = None
|
546 |
+
tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
|
547 |
+
filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold)
|
548 |
+
|
549 |
+
tp += tp_img
|
550 |
+
fp += fp_img
|
551 |
+
fn += fn_img
|
552 |
+
labels_t += labels_t_img
|
553 |
+
labels_f += labels_f_img
|
554 |
+
key_t += key_t_img
|
555 |
+
key_f += key_f_img
|
556 |
+
reverted += reverted_img
|
557 |
+
|
558 |
+
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
|
559 |
+
|
560 |
+
def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type = 'object'):
|
561 |
+
|
562 |
+
tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation(model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type)
|
563 |
+
|
564 |
+
labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
|
565 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
566 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
567 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
568 |
+
if model_type == 'arrow':
|
569 |
+
key_accuracy = key_t / (key_t + key_f) if (key_t + key_f) > 0 else 0
|
570 |
+
reverted_accuracy = reverted / (key_t + key_f) if (key_t + key_f) > 0 else 0
|
571 |
+
else:
|
572 |
+
key_accuracy = 0
|
573 |
+
reverted_accuracy = 0
|
574 |
+
|
575 |
+
return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
|
576 |
+
|
577 |
+
|
578 |
+
|
579 |
+
def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
|
580 |
+
matched_true_boxes = set()
|
581 |
+
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
582 |
+
match_found = False
|
583 |
+
for true_idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)):
|
584 |
+
if true_idx in matched_true_boxes:
|
585 |
+
continue
|
586 |
+
if pred_label == true_label and iou(np.array(pred_box), np.array(true_box)) >= iou_threshold:
|
587 |
+
class_tp[model_dict[pred_label]] += 1
|
588 |
+
matched_true_boxes.add(true_idx)
|
589 |
+
match_found = True
|
590 |
+
break
|
591 |
+
if not match_found:
|
592 |
+
class_fp[model_dict[pred_label]] += 1
|
593 |
+
|
594 |
+
for idx, true_label in enumerate(true_labels):
|
595 |
+
if idx not in matched_true_boxes:
|
596 |
+
class_fn[model_dict[true_label]] += 1
|
597 |
+
|
598 |
+
def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
|
599 |
+
model.eval()
|
600 |
+
with torch.no_grad():
|
601 |
+
for images, targets_im in tqdm(loader, desc="Testing... "):
|
602 |
+
devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
603 |
+
images = [image.to(devices) for image in images]
|
604 |
+
targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im]
|
605 |
+
|
606 |
+
predictions = model(images)
|
607 |
+
|
608 |
+
for target, prediction in zip(targets, predictions):
|
609 |
+
true_boxes = target['boxes'].cpu().numpy()
|
610 |
+
true_labels = target['labels'].cpu().numpy()
|
611 |
+
|
612 |
+
pred_boxes = prediction['boxes'].cpu().numpy()
|
613 |
+
scores = prediction['scores'].cpu().numpy()
|
614 |
+
pred_labels = prediction['labels'].cpu().numpy()
|
615 |
+
|
616 |
+
idx = np.where(scores > score_threshold)[0]
|
617 |
+
pred_boxes = pred_boxes[idx]
|
618 |
+
scores = scores[idx]
|
619 |
+
pred_labels = pred_labels[idx]
|
620 |
+
|
621 |
+
selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold)
|
622 |
+
pred_boxes = pred_boxes[selected_boxes]
|
623 |
+
scores = scores[selected_boxes]
|
624 |
+
pred_labels = pred_labels[selected_boxes]
|
625 |
+
|
626 |
+
yield pred_boxes, true_boxes, pred_labels, true_labels
|
627 |
+
|
628 |
+
def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
|
629 |
+
class_tp = {cls: 0 for cls in model_dict.values()}
|
630 |
+
class_fp = {cls: 0 for cls in model_dict.values()}
|
631 |
+
class_fn = {cls: 0 for cls in model_dict.values()}
|
632 |
+
|
633 |
+
for pred_boxes, true_boxes, pred_labels, true_labels in pred_4_evaluation_per_class(model, test_loader, score_threshold, iou_threshold):
|
634 |
+
evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold)
|
635 |
+
|
636 |
+
class_precision = {}
|
637 |
+
class_recall = {}
|
638 |
+
class_f1_score = {}
|
639 |
+
|
640 |
+
for cls in model_dict.values():
|
641 |
+
precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0
|
642 |
+
recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0
|
643 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
|
644 |
+
|
645 |
+
class_precision[cls] = precision
|
646 |
+
class_recall[cls] = recall
|
647 |
+
class_f1_score[cls] = f1_score
|
648 |
+
|
649 |
+
return class_precision, class_recall, class_f1_score
|
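A minimal sketch of how the two entry points above could be driven; the loaded models, the test loader and the input image are assumptions here, only full_prediction and main_evaluation come from eval.py.

# Hypothetical driver code; model_object, model_arrow and test_loader are assumed to exist.
import torch
from PIL import Image
from torchvision.transforms import functional as TF
from eval import full_prediction, main_evaluation

image = TF.to_tensor(Image.open("bpmn_sketch.png").convert("RGB"))  # placeholder path
with torch.no_grad():
    _, data = full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5)
print(data['BPMN_id'])   # one generated id per detected element
print(data['links'])     # per-arrow [source_index, target_index] pairs

labels_precision, precision, recall, f1, key_acc, rev_acc = main_evaluation(
    model_arrow, test_loader, score_threshold=0.5, iou_threshold=0.5, model_type='arrow')
print(f"F1: {f1:.3f}, keypoint accuracy: {key_acc:.3f}")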
flask.py
ADDED
@@ -0,0 +1,6 @@
from flask import Flask
app = Flask(__name__)

@app.route("/")
def hello():
    return "Hello World!\n"
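The file above only defines the hello-world app and has no entry point of its own. One way to run it locally is to add a standard main guard, shown here purely as an illustration (it is not part of the committed file):

# Hypothetical launcher, not in flask.py; the port number is an arbitrary choice.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)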
htlm_webpage.py
ADDED
@@ -0,0 +1,141 @@
import streamlit as st
import streamlit.components.v1 as components

def display_bpmn_xml(bpmn_xml):
    html_template = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>BPMN Modeler</title>
        <link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/diagram-js.css">
        <link rel="stylesheet" href="https://unpkg.com/bpmn-js/dist/assets/bpmn-font/css/bpmn-embedded.css">
        <script src="https://unpkg.com/bpmn-js/dist/bpmn-modeler.development.js"></script>
        <style>
            html, body {{
                height: 100%;
                padding: 0;
                margin: 0;
                font-family: Arial, sans-serif;
                display: flex;
                flex-direction: column;
                overflow: hidden;
            }}
            #button-container {{
                padding: 10px;
                background-color: #ffffff;
                border-bottom: 1px solid #ddd;
                display: flex;
                justify-content: flex-start;
                gap: 10px;
            }}
            #save-button, #download-button {{
                background-color: #4CAF50;
                color: white;
                border: none;
                padding: 10px 20px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 16px;
                margin: 4px 2px;
                cursor: pointer;
                border-radius: 8px;
            }}
            #download-button {{
                background-color: #008CBA;
            }}
            #canvas-container {{
                flex: 1;
                position: relative;
                background-color: #FBFBFB;
                overflow: hidden; /* Prevent scrolling */
                display: flex;
                justify-content: center;
                align-items: center;
            }}
            #canvas {{
                height: 100%;
                width: 100%;
                position: relative;
            }}
        </style>
    </head>
    <body>
        <div id="button-container">
            <button id="save-button">Save as BPMN</button>
            <button id="download-button">Save as XML</button>
            <button id="download-button">Save as Vizi</button>
        </div>
        <div id="canvas-container">
            <div id="canvas"></div>
        </div>
        <script>
            var bpmnModeler = new BpmnJS({{
                container: '#canvas'
            }});

            async function openDiagram(bpmnXML) {{
                try {{
                    await bpmnModeler.importXML(bpmnXML);
                    bpmnModeler.get('canvas').zoom('fit-viewport');
                    bpmnModeler.get('canvas').zoom(0.8); // Adjust this value for zooming out
                }} catch (err) {{
                    console.error('Error rendering BPMN diagram', err);
                }}
            }}

            async function saveDiagram() {{
                try {{
                    const result = await bpmnModeler.saveXML({{ format: true }});
                    const xml = result.xml;
                    const blob = new Blob([xml], {{ type: 'text/xml' }});
                    const url = URL.createObjectURL(blob);
                    const a = document.createElement('a');
                    a.href = url;
                    a.download = 'diagram.bpmn';
                    document.body.appendChild(a);
                    a.click();
                    document.body.removeChild(a);
                }} catch (err) {{
                    console.error('Error saving BPMN diagram', err);
                }}
            }}

            async function downloadXML() {{
                const xml = `{bpmn_xml}`;
                const blob = new Blob([xml], {{ type: 'text/xml' }});
                const url = URL.createObjectURL(blob);
                const a = document.createElement('a');
                a.href = url;
                a.download = 'diagram.xml';
                document.body.appendChild(a);
                a.click();
                document.body.removeChild(a);
            }}

            document.getElementById('save-button').addEventListener('click', saveDiagram);
            document.getElementById('download-button').addEventListener('click', downloadXML);

            // Ensure the canvas is focused to capture keyboard events
            document.getElementById('canvas').focus();

            // Add event listeners for keyboard shortcuts
            document.addEventListener('keydown', function(event) {{
                if (event.ctrlKey && event.key === 'z') {{
                    bpmnModeler.get('commandStack').undo();
                }} else if (event.key === 'Delete' || event.key === 'Backspace') {{
                    bpmnModeler.get('selection').get().forEach(function(element) {{
                        bpmnModeler.get('modeling').removeElements([element]);
                    }});
                }}
            }});

            openDiagram(`{bpmn_xml}`);
        </script>
    </body>
    </html>
    """

    components.html(html_template, height=1000, width=1500)
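A small sketch of how display_bpmn_xml can be exercised from a Streamlit page; the XML string below is a hand-written minimal diagram used only for illustration, whereas in the demo the XML comes from toXML.py.

# Hypothetical demo page; the BPMN string is an assumption, not repository output.
import streamlit as st
from htlm_webpage import display_bpmn_xml

minimal_bpmn = """<?xml version="1.0" encoding="UTF-8"?>
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
                  xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
                  xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
                  id="Definitions_1" targetNamespace="http://bpmn.io/schema/bpmn">
  <bpmn:process id="Process_1" isExecutable="false">
    <bpmn:startEvent id="start_event_1"/>
  </bpmn:process>
  <bpmndi:BPMNDiagram id="BPMNDiagram_1">
    <bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="Process_1">
      <bpmndi:BPMNShape id="start_event_1_di" bpmnElement="start_event_1">
        <dc:Bounds x="150" y="100" width="36" height="36"/>
      </bpmndi:BPMNShape>
    </bpmndi:BPMNPlane>
  </bpmndi:BPMNDiagram>
</bpmn:definitions>"""

st.title("BPMN viewer demo")
display_bpmn_xml(minimal_bpmn)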
packages.txt
ADDED
@@ -0,0 +1 @@
libgl1-mesa-glx
requirements.txt
ADDED
@@ -0,0 +1,10 @@
yamlu==0.0.17
tqdm==4.66.4
torchvision==0.18.0
azure-ai-vision-imageanalysis==1.0.0b2
streamlit==1.35.0
streamlit-image-comparison==0.0.4
streamlit-cropper==0.2.2
streamlit-drawable-canvas==0.9.3
opencv-python
gdown
toXML.py
ADDED
@@ -0,0 +1,351 @@
import xml.etree.ElementTree as ET
from utils import class_dict

def rescale(scale, boxes):
    for i in range(len(boxes)):
        boxes[i] = [boxes[i][0]*scale,
                    boxes[i][1]*scale,
                    boxes[i][2]*scale,
                    boxes[i][3]*scale]
    return boxes

def create_BPMN_id(data):
    enums = {
        'end_event': 1,
        'start_event': 1,
        'task': 1,
        'sequenceFlow': 1,
        'messageFlow': 1,
        'message_event': 1,
        'exclusiveGateway': 1,
        'parallelGateway': 1,
        'dataAssociation': 1,
        'pool': 1,
        'dataObject': 1,
        'timerEvent': 1
    }

    BPMN_name = [class_dict[label] for label in data['labels']]

    for idx, Bpmn_id in enumerate(BPMN_name):
        if Bpmn_id == 'event':
            if data['links'][idx][0] is not None and data['links'][idx][1] is None:
                key = 'end_event'
            elif data['links'][idx][0] is None and data['links'][idx][1] is not None:
                key = 'start_event'
        else:
            key = {
                'task': 'task',
                'dataObject': 'dataObject',
                'sequenceFlow': 'sequenceFlow',
                'messageFlow': 'messageFlow',
                'messageEvent': 'message_event',
                'exclusiveGateway': 'exclusiveGateway',
                'parallelGateway': 'parallelGateway',
                'dataAssociation': 'dataAssociation',
                'pool': 'pool',
                'timerEvent': 'timerEvent'
            }.get(Bpmn_id, None)

        if key:
            data['BPMN_id'][idx] = f'{key}_{enums[key]}'
            enums[key] += 1

    return data


def add_diagram_elements(parent, element_id, x, y, width, height):
    """Utility to add BPMN diagram notation for elements."""
    shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
        'bpmnElement': element_id,
        'id': element_id + '_di'
    })
    bounds = ET.SubElement(shape, 'dc:Bounds', attrib={
        'x': str(x),
        'y': str(y),
        'width': str(width),
        'height': str(height)
    })

def add_diagram_edge(parent, element_id, waypoints):
    """Utility to add BPMN diagram notation for sequence flows."""
    edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
        'bpmnElement': element_id,
        'id': element_id + '_di'
    })
    for x, y in waypoints:
        ET.SubElement(edge, 'di:waypoint', attrib={
            'x': str(x),
            'y': str(y)
        })


def check_status(link, keep_elements):
    if link[0] in keep_elements and link[1] in keep_elements:
        return 'middle'
    elif link[0] is None and link[1] in keep_elements:
        return 'start'
    elif link[0] in keep_elements and link[1] is None:
        return 'end'
    else:
        return 'middle'

def check_data_association(i, links, labels, keep_elements):
    for j, (k, l) in enumerate(links):
        if labels[j] == 14:
            if k == i:
                return 'output', j
            elif l == i:
                return 'input', j

    return 'no association', None

def create_data_Association(bpmn, data, size, element_id, source_id, target_id):
    waypoints = calculate_waypoints(data, size, source_id, target_id)
    add_diagram_edge(bpmn, element_id, waypoints)

# Function to dynamically create and layout BPMN elements
def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
    elements = data['BPMN_id']
    positions = data['boxes']
    links = data['links']

    for i in keep_elements:
        element_id = elements[i]
        if element_id is None:
            continue

        element_type = element_id.split('_')[0]
        x, y = positions[i][:2]

        # Start Event
        if element_type == 'start':
            element = ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
            add_diagram_elements(bpmnplane, element_id, x, y, size['start'][0], size['start'][1])

        # Task
        elif element_type == 'task':
            element = ET.SubElement(process, 'bpmn:task', id=element_id, name=text_mapping[element_id])
            status, dataAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)

            # Handle Data Input Association
            if status == 'input':
                dataObject_idx = links[dataAssociation_idx][0]
                dataObject_name = elements[dataObject_idx]
                dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
                sub_element = ET.SubElement(element, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{dataObject_ref.split("_")[1]}')
                ET.SubElement(sub_element, 'bpmn:sourceRef').text = dataObject_ref
                create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataObject_name, element_id)

            # Handle Data Output Association
            elif status == 'output':
                dataObject_idx = links[dataAssociation_idx][1]
                dataObject_name = elements[dataObject_idx]
                dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
                sub_element = ET.SubElement(element, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{dataObject_ref.split("_")[1]}')
                ET.SubElement(sub_element, 'bpmn:targetRef').text = dataObject_ref
                create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], element_id, dataObject_name)

            add_diagram_elements(bpmnplane, element_id, x, y, size['task'][0], size['task'][1])

        # Message Events (Start, Intermediate, End)
        elif element_type == 'message':
            status = check_status(links[i], keep_elements)
            if status == 'start':
                element = ET.SubElement(process, 'bpmn:startEvent', id=element_id, name=text_mapping[element_id])
            elif status == 'middle':
                element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
            elif status == 'end':
                element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
            ET.SubElement(element, 'bpmn:messageEventDefinition', id=f'MessageEventDefinition_{i+1}')
            add_diagram_elements(bpmnplane, element_id, x, y, size['message'][0], size['message'][1])

        # End Event
        elif element_type == 'end':
            element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
            add_diagram_elements(bpmnplane, element_id, x, y, size['end'][0], size['end'][1])

        # Gateways (Exclusive, Parallel)
        elif element_type in ['exclusiveGateway', 'parallelGateway']:
            gateway_type = 'exclusiveGateway' if element_type == 'exclusiveGateway' else 'parallelGateway'
            element = ET.SubElement(process, f'bpmn:{gateway_type}', id=element_id)
            add_diagram_elements(bpmnplane, element_id, x, y, size[element_type][0], size[element_type][1])

        # Data Object
        elif element_type == 'dataObject':
            dataObject_idx = element_id.split('_')[1]
            dataObject_ref = f'DataObjectReference_{dataObject_idx}'
            element = ET.SubElement(process, 'bpmn:dataObjectReference', id=dataObject_ref, dataObjectRef=element_id, name=text_mapping[element_id])
            ET.SubElement(process, 'bpmn:dataObject', id=element_id)
            add_diagram_elements(bpmnplane, dataObject_ref, x, y, size['dataObject'][0], size['dataObject'][1])

        # Timer Event
        elif element_type == 'timerEvent':
            element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
            ET.SubElement(element, 'bpmn:timerEventDefinition', id=f'TimerEventDefinition_{i+1}')
            add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])


# Calculate simple waypoints between two elements (this function assumes direct horizontal links for simplicity)
def calculate_waypoints(data, size, source_id, target_id):
    source_idx = data['BPMN_id'].index(source_id)
    target_idx = data['BPMN_id'].index(target_id)
    name_source = source_id.split('_')[0]
    name_target = target_id.split('_')[0]

    # Get the position of the source and target
    source_x, source_y = data['boxes'][source_idx][:2]
    target_x, target_y = data['boxes'][target_idx][:2]

    # Calculate relative position between source and target from their centers
    relative_x = (target_x + size[name_target][0]) / 2 - (source_x + size[name_source][0]) / 2
    relative_y = (target_y + size[name_target][1]) / 2 - (source_y + size[name_source][1]) / 2

    # Get the size of the elements
    size_x_source = size[name_source][0]
    size_y_source = size[name_source][1]
    size_x_target = size[name_target][0]
    size_y_target = size[name_target][1]

    # if it going to right
    if relative_x >= size[name_source][0]:
        source_x += size_x_source
        source_y += size_y_source / 2
        target_x = target_x
        target_y += size_y_target / 2
        # if the source is going up
        if relative_y < -size[name_source][1]:
            source_x -= size_x_source / 2
            source_y -= size_y_source / 2
        # if the source is going down
        elif relative_y > size[name_source][1]:
            source_x -= size_x_source / 2
            source_y += size_y_source / 2
    # if it going to left
    elif relative_x < -size[name_source][0]:
        source_x = source_x
        source_y += size_y_source / 2
        target_x += size_x_target
        target_y += size_y_target / 2
        # if the source is going up
        if relative_y < -size[name_source][1]:
            source_x += size_x_source / 2
            source_y -= size_y_source / 2
        # if the source is going down
        elif relative_y > size[name_source][1]:
            source_x += size_x_source / 2
            source_y += size_y_source / 2
    # if it going up and down
    elif -size[name_source][0] < relative_x < size[name_source][0]:
        source_x += size_x_source / 2
        target_x += size_x_target / 2
        # if it's going down
        if relative_y >= size[name_source][1] / 2:
            source_y += size_y_source
        # if it's going up
        elif relative_y < -size[name_source][1] / 2:
            source_y = source_y
            target_y += size_y_target
        else:
            if relative_x >= 0:
                source_x += size_x_source / 2
                source_y += size_y_source / 2
                target_x -= size_x_target / 2
                target_y += size_y_target / 2
            else:
                source_x -= size_x_source / 2
                source_y += size_y_source / 2
                target_x += size_x_target / 2
                target_y += size_y_target / 2

    return [(source_x, source_y), (target_x, target_y)]


def calculate_pool_bounds(data, keep_elements, size):
    min_x = min_y = float('10000')
    max_x = max_y = float('0')

    for i in keep_elements:
        if i >= len(data['BPMN_id']):
            print("Problem with the index")
            continue
        element = data['BPMN_id'][i]
        if element is None or data['labels'][i] == 13 or data['labels'][i] == 14 or data['labels'][i] == 15 or data['labels'][i] == 7 or data['labels'][i] == 15:
            continue

        element_type = element.split('_')[0]
        x, y = data['boxes'][i][:2]
        element_width, element_height = size[element_type]

        min_x = min(min_x, x)
        min_y = min(min_y, y)
        max_x = max(max_x, x + element_width)
        max_y = max(max_y, y + element_height)

    return min_x, min_y, max_x, max_y


def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
    # Get the bounding boxes of the source and target elements
    source_box = data['boxes'][source_idx]
    target_box = data['boxes'][target_idx]

    # Get the midpoints of the source element
    source_mid_x = (source_box[0] + source_box[2]) / 2
    source_mid_y = (source_box[1] + source_box[3]) / 2

    # Check if the connection involves a pool
    if source_element == 'pool':
        pool_box = source_box
        element_box = (target_box[0], target_box[1], target_box[0] + size[target_element][0], target_box[1] + size[target_element][1])
        element_mid_x = (element_box[0] + element_box[2]) / 2
        element_mid_y = (element_box[1] + element_box[3]) / 2
        # Connect the pool's bottom or top side to the target element's top or bottom center
        if pool_box[3] < element_box[1]:  # Pool is above the target element
            waypoints = [(element_mid_x, pool_box[3] - 50), (element_mid_x, element_box[1])]
        else:  # Pool is below the target element
            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1] - 50)]
    else:
        pool_box = target_box
        element_box = (source_box[0], source_box[1], source_box[0] + size[source_element][0], source_box[1] + size[source_element][1])
        element_mid_x = (element_box[0] + element_box[2]) / 2
        element_mid_y = (element_box[1] + element_box[3]) / 2

        # Connect the element's bottom or top center to the pool's top or bottom side
        if pool_box[3] < element_box[1]:  # Pool is above the target element
            waypoints = [(element_mid_x, element_box[1]), (element_mid_x, pool_box[3] - 50)]
        else:  # Pool is below the target element
            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1] - 50)]

    return waypoints


def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
    source_idx, target_idx = data['links'][idx]
    source_id, target_id = data['BPMN_id'][source_idx], data['BPMN_id'][target_idx]
    if message:
        element_id = f'messageflow_{source_id}_{target_id}'
    else:
        element_id = f'sequenceflow_{source_id}_{target_id}'

    if source_id.split('_')[0] == 'pool' or target_id.split('_')[0] == 'pool':
        waypoints = calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_id.split('_')[0], target_id.split('_')[0])
        #waypoints = data['best_points'][idx]
        if source_id.split('_')[0] == 'pool':
            source_id = f"participant_{source_id.split('_')[1]}"
        if target_id.split('_')[0] == 'pool':
            target_id = f"participant_{target_id.split('_')[1]}"
    else:
        waypoints = calculate_waypoints(data, size, source_id, target_id)
        #waypoints = data['best_points'][idx]

    #waypoints = data['best_points'][idx]
    if message:
        element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
    else:
        element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
    add_diagram_edge(bpmn, element_id, waypoints)
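A minimal sketch (not part of toXML.py) of how the two notation helpers are meant to be called when assembling the BPMNDiagram section; the element ids, waypoints and coordinates below are placeholders chosen for illustration.

# Hypothetical usage of add_diagram_elements / add_diagram_edge; ids and numbers are made up.
import xml.etree.ElementTree as ET
from toXML import add_diagram_elements, add_diagram_edge

definitions = ET.Element('bpmn:definitions')
diagram = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
plane = ET.SubElement(diagram, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='Process_1')

# One task shape and one straight two-point edge.
add_diagram_elements(plane, 'task_1', x=100, y=80, width=160, height=80)
add_diagram_edge(plane, 'sequenceflow_task_1_task_2', [(260, 120), (340, 120)])

print(ET.tostring(definitions, encoding='unicode'))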
train.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from eval import main_evaluation
|
11 |
+
from torch.optim import SGD, AdamW
|
12 |
+
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
|
13 |
+
from torch.utils.data.dataloader import default_collate
|
14 |
+
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
15 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
16 |
+
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
|
17 |
+
from tqdm import tqdm
|
18 |
+
from utils import write_results
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
def get_arrow_model(num_classes, num_keypoints=2):
|
24 |
+
"""
|
25 |
+
Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
|
26 |
+
|
27 |
+
Parameters:
|
28 |
+
- num_classes (int): Number of classes for the model to detect, excluding the background class.
|
29 |
+
- num_keypoints (int): Number of keypoints to predict for each detected object.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
- model (torch.nn.Module): The modified Keypoint R-CNN model.
|
33 |
+
|
34 |
+
Steps:
|
35 |
+
1. Load a pre-trained Keypoint R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN).
|
36 |
+
The model is initially configured for the COCO dataset, which includes various object classes and keypoints.
|
37 |
+
2. Replace the box predictor to adjust the number of output classes. The box predictor is responsible for
|
38 |
+
classifying detected regions and predicting their bounding boxes.
|
39 |
+
3. Replace the keypoint predictor to adjust the number of keypoints the model predicts for each object.
|
40 |
+
This is necessary to tailor the model to specific tasks that may have different keypoint structures.
|
41 |
+
"""
|
42 |
+
# Load a model pre-trained on COCO, initialized without pre-trained weights
|
43 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
44 |
+
if device == torch.device('cuda'):
|
45 |
+
model = keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1)
|
46 |
+
else:
|
47 |
+
model = keypointrcnn_resnet50_fpn(weights=False)
|
48 |
+
|
49 |
+
# Get the number of input features for the classifier in the box predictor.
|
50 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
51 |
+
|
52 |
+
# Replace the box predictor in the ROI heads with a new one, tailored to the number of classes.
|
53 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
54 |
+
|
55 |
+
# Replace the keypoint predictor in the ROI heads with a new one, specifically designed for the desired number of keypoints.
|
56 |
+
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(512, num_keypoints)
|
57 |
+
|
58 |
+
return model
|
59 |
+
|
60 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
61 |
+
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
|
62 |
+
def get_faster_rcnn_model(num_classes):
|
63 |
+
"""
|
64 |
+
Configures and returns a modified Faster R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes.
|
65 |
+
|
66 |
+
Parameters:
|
67 |
+
- num_classes (int): Number of classes for the model to detect, including the background class.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
- model (torch.nn.Module): The modified Faster R-CNN model.
|
71 |
+
"""
|
72 |
+
# Load a pre-trained Faster R-CNN model
|
73 |
+
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
|
74 |
+
|
75 |
+
# Get the number of input features for the classifier in the box predictor
|
76 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
77 |
+
|
78 |
+
# Replace the box predictor with a new one, tailored to the number of classes (num_classes includes the background)
|
79 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
80 |
+
|
81 |
+
return model
|
82 |
+
|
83 |
+
def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type = 'object'):
|
84 |
+
# Adjusted to pass the class_dict directly
|
85 |
+
if model_type == 'object':
|
86 |
+
model = get_faster_rcnn_model(len(dict))
|
87 |
+
elif model_type == 'arrow':
|
88 |
+
model = get_arrow_model(len(dict),2)
|
89 |
+
|
90 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
91 |
+
# Load the model weights
|
92 |
+
if model_to_load:
|
93 |
+
model.load_state_dict(torch.load('./models/'+ model_to_load +'.pth', map_location=device))
|
94 |
+
print(f"Model '{model_to_load}' loaded")
|
95 |
+
|
96 |
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
97 |
+
model.to(device)
|
98 |
+
|
99 |
+
if opti == 'SGD':
|
100 |
+
#learning_rate= 0.002
|
101 |
+
optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
|
102 |
+
elif opti == 'Adam':
|
103 |
+
#learning_rate = 0.0003
|
104 |
+
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
|
105 |
+
else:
|
106 |
+
print('Optimizer not found')
|
107 |
+
|
108 |
+
return model, optimizer, device
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
|
114 |
+
model.train() # Set the model to evaluation mode
|
115 |
+
total_loss = 0
|
116 |
+
|
117 |
+
# Initialize lists to keep track of individual losses
|
118 |
+
loss_classifier_list = []
|
119 |
+
loss_box_reg_list = []
|
120 |
+
loss_objectness_list = []
|
121 |
+
loss_rpn_box_reg_list = []
|
122 |
+
loss_keypoints_list = []
|
123 |
+
|
124 |
+
with torch.no_grad(): # Disable gradient computation
|
125 |
+
for images, targets_im in tqdm(data_loader, desc="Evaluating"):
|
126 |
+
images = [image.to(device) for image in images]
|
127 |
+
targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
|
128 |
+
|
129 |
+
loss_dict = model(images, targets)
|
130 |
+
|
131 |
+
# Calculate the total loss for the current batch
|
132 |
+
losses = 0
|
133 |
+
if loss_config is not None:
|
134 |
+
for key, loss in loss_dict.items():
|
135 |
+
if loss_config.get(key, False):
|
136 |
+
losses += loss
|
137 |
+
else:
|
138 |
+
losses = sum(loss for key, loss in loss_dict.items())
|
139 |
+
|
140 |
+
total_loss += losses.item()
|
141 |
+
|
142 |
+
# Collect individual losses
|
143 |
+
if loss_dict.get('loss_classifier') is not None:
|
144 |
+
loss_classifier_list.append(loss_dict['loss_classifier'].item())
|
145 |
+
else:
|
146 |
+
loss_classifier_list.append(0)
|
147 |
+
|
148 |
+
if loss_dict.get('loss_box_reg') is not None:
|
149 |
+
loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
|
150 |
+
else:
|
151 |
+
loss_box_reg_list.append(0)
|
152 |
+
|
153 |
+
if loss_dict.get('loss_objectness') is not None:
|
154 |
+
loss_objectness_list.append(loss_dict['loss_objectness'].item())
|
155 |
+
else:
|
156 |
+
loss_objectness_list.append(0)
|
157 |
+
|
158 |
+
if loss_dict.get('loss_rpn_box_reg') is not None:
|
159 |
+
loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
|
160 |
+
else:
|
161 |
+
loss_rpn_box_reg_list.append(0)
|
162 |
+
|
163 |
+
if 'loss_keypoint' in loss_dict:
|
164 |
+
loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
|
165 |
+
else:
|
166 |
+
loss_keypoints_list.append(0)
|
167 |
+
|
168 |
+
# Calculate average loss
|
169 |
+
avg_loss = total_loss / len(data_loader)
|
170 |
+
|
171 |
+
avg_loss_classifier = np.mean(loss_classifier_list)
|
172 |
+
avg_loss_box_reg = np.mean(loss_box_reg_list)
|
173 |
+
avg_loss_objectness = np.mean(loss_objectness_list)
|
174 |
+
avg_loss_rpn_box_reg = np.mean(loss_rpn_box_reg_list)
|
175 |
+
avg_loss_keypoints = np.mean(loss_keypoints_list)
|
176 |
+
|
177 |
+
if print_losses:
|
178 |
+
print(f"Average Loss: {avg_loss:.4f}")
|
179 |
+
print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
|
180 |
+
print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
|
181 |
+
print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
|
182 |
+
print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
|
183 |
+
print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")
|
184 |
+
|
185 |
+
return avg_loss
|
186 |
+
|
187 |
+
|
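As a usage sketch, the helper above can be pointed at a held-out loader once the model and device are prepared; test_loader and the exact loss_config entries below are illustrative assumptions, not part of the committed code.
loss_config = {'loss_classifier': True, 'loss_box_reg': True,
               'loss_objectness': True, 'loss_rpn_box_reg': True}
# test_loader is assumed to be built the same way as the training loader
test_loss = evaluate_loss(model, test_loader, device,
                          loss_config=loss_config, print_losses=True)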
188 |
+
def training_model(num_epochs, model, data_loader, subset_test_loader,
|
189 |
+
optimizer, model_to_load=None, change_learning_rate=5, start_key=30,
|
190 |
+
batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
|
191 |
+
max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
|
192 |
+
score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
|
193 |
+
information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
|
194 |
+
eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
|
195 |
+
|
196 |
+
|
197 |
+
if loss_config is None:
|
198 |
+
print('No loss config found, all losses will be used.')
|
199 |
+
else:
|
200 |
+
#print the list of the losses that will be used
|
201 |
+
print('The following losses will be used: ', end='')
|
202 |
+
for key, value in loss_config.items():
|
203 |
+
if value:
|
204 |
+
print(key, end=", ")
|
205 |
+
print()
|
206 |
+
|
207 |
+
|
208 |
+
# Initialize lists to store epoch-wise average losses
|
209 |
+
epoch_avg_losses = []
|
210 |
+
epoch_avg_loss_classifier = []
|
211 |
+
epoch_avg_loss_box_reg = []
|
212 |
+
epoch_avg_loss_objectness = []
|
213 |
+
epoch_avg_loss_rpn_box_reg = []
|
214 |
+
epoch_avg_loss_keypoints = []
|
215 |
+
epoch_precision = []
|
216 |
+
epoch_recall = []
|
217 |
+
epoch_f1_score = []
|
218 |
+
epoch_test_loss = []
|
219 |
+
|
220 |
+
|
221 |
+
start_tot = time.time()
|
222 |
+
best_metrics = -1000
|
223 |
+
best_epoch = 0
|
224 |
+
best_model_state = None
|
225 |
+
same = 0
|
226 |
+
learning_rate = optimizer.param_groups[0]['lr']
|
227 |
+
bad_test_loss = 0
|
228 |
+
previous_test_loss = 1000
|
229 |
+
|
230 |
+
print(f"Let's go training {model_type} model with {num_epochs} epochs!")
|
231 |
+
print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, Flip prob: {h_flip_prob}, Rotate prob: {rotate_proba}, Blur prob: {blur_prob}")
|
232 |
+
|
233 |
+
for epoch in range(num_epochs):
|
234 |
+
|
235 |
+
if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>1:
|
236 |
+
learning_rate = 0.7*learning_rate
|
237 |
+
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
|
238 |
+
print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
|
239 |
+
bad_test_loss = 0
|
240 |
+
if epoch>0 and (epoch)==start_key:
|
241 |
+
print("Now it's training Keypoints also")
|
242 |
+
loss_config['loss_keypoint'] = True
|
243 |
+
for name, param in model.named_parameters():
|
244 |
+
if 'keypoint' in name:
|
245 |
+
param.requires_grad = True
|
246 |
+
|
247 |
+
model.train()
|
248 |
+
start = time.time()
|
249 |
+
total_loss = 0
|
250 |
+
|
251 |
+
# Initialize lists to keep track of individual losses
|
252 |
+
loss_classifier_list = []
|
253 |
+
loss_box_reg_list = []
|
254 |
+
loss_objectness_list = []
|
255 |
+
loss_rpn_box_reg_list = []
|
256 |
+
loss_keypoints_list = []
|
257 |
+
|
258 |
+
# Create a tqdm progress bar
|
259 |
+
progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')
|
260 |
+
|
261 |
+
for images, targets_im in progress_bar:
|
262 |
+
images = [image.to(device) for image in images]
|
263 |
+
targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
|
264 |
+
|
265 |
+
optimizer.zero_grad()
|
266 |
+
|
267 |
+
loss_dict = model(images, targets)
|
268 |
+
# Inside the training loop where losses are calculated:
|
269 |
+
losses = 0
|
270 |
+
if loss_config is not None:
|
271 |
+
for key, loss in loss_dict.items():
|
272 |
+
if loss_config.get(key, False):
|
273 |
+
if key == 'loss_classifier':
|
274 |
+
loss *= 3  # up-weight the classifier loss relative to the other loss terms
|
275 |
+
losses += loss
|
276 |
+
else:
|
277 |
+
losses = sum(loss for key, loss in loss_dict.items())
|
278 |
+
|
279 |
+
# Collect individual losses
|
280 |
+
if loss_dict['loss_classifier']:
|
281 |
+
loss_classifier_list.append(loss_dict['loss_classifier'].item())
|
282 |
+
else:
|
283 |
+
loss_classifier_list.append(0)
|
284 |
+
|
285 |
+
if loss_dict['loss_box_reg']:
|
286 |
+
loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
|
287 |
+
else:
|
288 |
+
loss_box_reg_list.append(0)
|
289 |
+
|
290 |
+
if loss_dict['loss_objectness']:
|
291 |
+
loss_objectness_list.append(loss_dict['loss_objectness'].item())
|
292 |
+
else:
|
293 |
+
loss_objectness_list.append(0)
|
294 |
+
|
295 |
+
if loss_dict['loss_rpn_box_reg']:
|
296 |
+
loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
|
297 |
+
else:
|
298 |
+
loss_rpn_box_reg_list.append(0)
|
299 |
+
|
300 |
+
if 'loss_keypoint' in loss_dict:
|
301 |
+
loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
|
302 |
+
else:
|
303 |
+
loss_keypoints_list.append(0)
|
304 |
+
|
305 |
+
|
306 |
+
losses.backward()
|
307 |
+
optimizer.step()
|
308 |
+
|
309 |
+
total_loss += losses.item()
|
310 |
+
|
311 |
+
# Update the description with the current loss
|
312 |
+
progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')
|
313 |
+
|
314 |
+
# Calculate average loss
|
315 |
+
avg_loss = total_loss / len(data_loader)
|
316 |
+
|
317 |
+
epoch_avg_losses.append(avg_loss)
|
318 |
+
epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
|
319 |
+
epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
|
320 |
+
epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
|
321 |
+
epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
|
322 |
+
epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
|
323 |
+
|
324 |
+
|
325 |
+
# Evaluate the model on the test set
|
326 |
+
if eval_metric != 'loss':
|
327 |
+
avg_test_loss = 0
|
328 |
+
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
|
329 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
|
330 |
+
if eval_metric == 'all':
|
331 |
+
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
332 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
333 |
+
if eval_metric == 'loss':
|
334 |
+
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
|
335 |
+
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
336 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
337 |
+
|
338 |
+
print(f"Time: {time.time() - start:.2f} [s]")
|
339 |
+
|
340 |
+
|
341 |
+
if epoch>0 and (epoch)%start_key == 0:
|
342 |
+
print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")
|
343 |
+
|
344 |
+
if eval_metric == 'f1_score':
|
345 |
+
metric_used = f1_score
|
346 |
+
elif eval_metric == 'precision':
|
347 |
+
metric_used = precision
|
348 |
+
elif eval_metric == 'recall':
|
349 |
+
metric_used = recall
|
350 |
+
else:
|
351 |
+
metric_used = -avg_test_loss
|
352 |
+
|
353 |
+
# Check if this epoch's model has the lowest average loss
|
354 |
+
if metric_used > best_metrics:
|
355 |
+
best_metrics = metric_used
|
356 |
+
best_epoch = epoch+1+start_epoch
|
357 |
+
best_model_state = copy.deepcopy(model.state_dict())
|
358 |
+
|
359 |
+
if epoch>0 and f1_score>early_stop_f1_score:
|
360 |
+
same+=1
|
361 |
+
|
362 |
+
epoch_precision.append(precision)
|
363 |
+
epoch_recall.append(recall)
|
364 |
+
epoch_f1_score.append(f1_score)
|
365 |
+
epoch_test_loss.append(avg_test_loss)
|
366 |
+
|
367 |
+
name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
|
368 |
+
|
369 |
+
if same >=1 :
|
370 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
371 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
372 |
+
write_results(name_model,metrics_list,start_epoch)
|
373 |
+
break
|
374 |
+
|
375 |
+
if (epoch+1+start_epoch) % 5 == 0:
|
376 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
377 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
378 |
+
model.load_state_dict(best_model_state)
|
379 |
+
write_results(name_model,metrics_list,start_epoch)
|
380 |
+
|
381 |
+
if avg_test_loss > previous_test_loss:
|
382 |
+
bad_test_loss += 1
|
383 |
+
previous_test_loss = avg_test_loss
|
384 |
+
|
385 |
+
|
386 |
+
print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an f1_score of {best_metrics:.4f}")
|
387 |
+
if best_model_state:
|
388 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
389 |
+
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
390 |
+
model.load_state_dict(best_model_state)
|
391 |
+
write_results(name_model,metrics_list,start_epoch)
|
392 |
+
print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
|
393 |
+
|
394 |
+
return model, metrics_list
|
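A rough sketch of a full training call, assuming the model, optimizer and the two loaders were prepared with the helpers above (all names are illustrative):
model, metrics_list = training_model(
    num_epochs=50, model=model, data_loader=train_loader,
    subset_test_loader=subset_test_loader, optimizer=optimizer,
    batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
    model_type='object', eval_metric='f1_score')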
utils.py
ADDED
@@ -0,0 +1,936 @@
1 |
+
from torchvision.models.detection import keypointrcnn_resnet50_fpn
|
2 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
3 |
+
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
|
4 |
+
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
import numpy as np
|
10 |
+
from torch.utils.data.dataloader import default_collate
|
11 |
+
import cv2
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from torch.utils.data import DataLoader, Subset, ConcatDataset
|
14 |
+
from tqdm import tqdm
|
15 |
+
from torch.optim import SGD
|
16 |
+
import time
|
17 |
+
from torch.optim import AdamW
|
18 |
+
import copy
|
19 |
+
from torchvision import transforms
|
20 |
+
|
21 |
+
|
22 |
+
object_dict = {
|
23 |
+
0: 'background',
|
24 |
+
1: 'task',
|
25 |
+
2: 'exclusiveGateway',
|
26 |
+
3: 'event',
|
27 |
+
4: 'parallelGateway',
|
28 |
+
5: 'messageEvent',
|
29 |
+
6: 'pool',
|
30 |
+
7: 'lane',
|
31 |
+
8: 'dataObject',
|
32 |
+
9: 'dataStore',
|
33 |
+
10: 'subProcess',
|
34 |
+
11: 'eventBasedGateway',
|
35 |
+
12: 'timerEvent',
|
36 |
+
}
|
37 |
+
|
38 |
+
arrow_dict = {
|
39 |
+
0: 'background',
|
40 |
+
1: 'sequenceFlow',
|
41 |
+
2: 'dataAssociation',
|
42 |
+
3: 'messageFlow',
|
43 |
+
}
|
44 |
+
|
45 |
+
class_dict = {
|
46 |
+
0: 'background',
|
47 |
+
1: 'task',
|
48 |
+
2: 'exclusiveGateway',
|
49 |
+
3: 'event',
|
50 |
+
4: 'parallelGateway',
|
51 |
+
5: 'messageEvent',
|
52 |
+
6: 'pool',
|
53 |
+
7: 'lane',
|
54 |
+
8: 'dataObject',
|
55 |
+
9: 'dataStore',
|
56 |
+
10: 'subProcess',
|
57 |
+
11: 'eventBasedGateway',
|
58 |
+
12: 'timerEvent',
|
59 |
+
13: 'sequenceFlow',
|
60 |
+
14: 'dataAssociation',
|
61 |
+
15: 'messageFlow',
|
62 |
+
}
|
63 |
+
|
64 |
+
def rescale_boxes(scale, boxes):
|
65 |
+
for i in range(len(boxes)):
|
66 |
+
boxes[i] = [boxes[i][0]*scale,
|
67 |
+
boxes[i][1]*scale,
|
68 |
+
boxes[i][2]*scale,
|
69 |
+
boxes[i][3]*scale]
|
70 |
+
return boxes
|
71 |
+
|
72 |
+
def iou(box1, box2):
|
73 |
+
# Compute the intersection of the two bounding boxes
|
74 |
+
inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
|
75 |
+
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
76 |
+
|
77 |
+
# Compute the union of the two bounding boxes
|
78 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
79 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
80 |
+
union_area = box1_area + box2_area - inter_area
|
81 |
+
|
82 |
+
return inter_area / union_area
|
83 |
+
|
84 |
+
def proportion_inside(box1, box2):
|
85 |
+
# Calculate the intersection of the two bounding boxes
|
86 |
+
inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
|
87 |
+
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
88 |
+
|
89 |
+
# Calculate the area of box1
|
90 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
91 |
+
|
92 |
+
# Calculate the proportion of box1 inside box2
|
93 |
+
if box1_area == 0:
|
94 |
+
return 0
|
95 |
+
proportion = inter_area / box1_area
|
96 |
+
|
97 |
+
# Ensure the proportion is at most 100%
|
98 |
+
return min(proportion, 1.0)
|
99 |
+
|
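A quick hand-checked example of the two geometric helpers above:
box_a = [0, 0, 10, 10]
box_b = [5, 0, 15, 10]
iou(box_a, box_b)                # 50 / 150 = 0.333...
proportion_inside(box_a, box_b)  # half of box_a lies inside box_b -> 0.5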
100 |
+
def resize_boxes(boxes, original_size, target_size):
|
101 |
+
"""
|
102 |
+
Resizes bounding boxes according to a new image size.
|
103 |
+
|
104 |
+
Parameters:
|
105 |
+
- boxes (np.array): The original bounding boxes as a numpy array of shape [N, 4].
|
106 |
+
- original_size (tuple): The original size of the image as (width, height).
|
107 |
+
- target_size (tuple): The desired size to resize the image to as (width, height).
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
- np.array: The resized bounding boxes as a numpy array of shape [N, 4].
|
111 |
+
"""
|
112 |
+
orig_width, orig_height = original_size
|
113 |
+
target_width, target_height = target_size
|
114 |
+
|
115 |
+
# Calculate the ratios for width and height
|
116 |
+
width_ratio = target_width / orig_width
|
117 |
+
height_ratio = target_height / orig_height
|
118 |
+
|
119 |
+
# Apply the ratios to the bounding boxes
|
120 |
+
boxes[:, 0] *= width_ratio
|
121 |
+
boxes[:, 1] *= height_ratio
|
122 |
+
boxes[:, 2] *= width_ratio
|
123 |
+
boxes[:, 3] *= height_ratio
|
124 |
+
|
125 |
+
return boxes
|
126 |
+
|
127 |
+
def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray:
|
128 |
+
"""
|
129 |
+
Resize keypoints based on the original and target dimensions of an image.
|
130 |
+
|
131 |
+
Parameters:
|
132 |
+
- keypoints (np.ndarray): The array of keypoints, where each keypoint is represented by its (x, y) coordinates.
|
133 |
+
- original_size (tuple): The width and height of the original image (width, height).
|
134 |
+
- target_size (tuple): The width and height of the target image (width, height).
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
- np.ndarray: The resized keypoints.
|
138 |
+
|
139 |
+
Explanation:
|
140 |
+
The function calculates the ratio of the target dimensions to the original dimensions.
|
141 |
+
It then applies these ratios to the x and y coordinates of each keypoint to scale them
|
142 |
+
appropriately to the target image size.
|
143 |
+
"""
|
144 |
+
|
145 |
+
orig_width, orig_height = original_size
|
146 |
+
target_width, target_height = target_size
|
147 |
+
|
148 |
+
# Calculate the ratios for width and height scaling
|
149 |
+
width_ratio = target_width / orig_width
|
150 |
+
height_ratio = target_height / orig_height
|
151 |
+
|
152 |
+
# Apply the scaling ratios to the x and y coordinates of each keypoint
|
153 |
+
keypoints[:, 0] *= width_ratio # Scale x coordinates
|
154 |
+
keypoints[:, 1] *= height_ratio # Scale y coordinates
|
155 |
+
|
156 |
+
return keypoints
|
157 |
+
|
158 |
+
|
159 |
+
|
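For example, resizing annotations from a 1000x500 image to 500x250 halves every coordinate (both helpers modify their input array in place and also return it):
boxes = np.array([[100., 50., 200., 150.]])
resize_boxes(boxes, (1000, 500), (500, 250))      # [[ 50.  25. 100.  75.]]
kps = np.array([[100., 50., 1.]])
resize_keypoints(kps, (1000, 500), (500, 250))    # [[50. 25.  1.]] (visibility untouched)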
160 |
+
class RandomCrop:
|
161 |
+
def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
|
162 |
+
self.crop_fraction = crop_fraction
|
163 |
+
self.min_objects = min_objects
|
164 |
+
self.new_size = new_size
|
165 |
+
|
166 |
+
def __call__(self, image, target):
|
167 |
+
new_w1, new_h1 = self.new_size
|
168 |
+
w, h = image.size
|
169 |
+
new_w = int(w * self.crop_fraction)
|
170 |
+
new_h = int(new_w*new_h1/new_w1)
|
171 |
+
|
172 |
+
shrink = 0  # progressively shrink the crop fraction until the crop height fits inside the image
|
173 |
+
for _ in range(4):
|
174 |
+
if new_h >= h:
|
175 |
+
shrink += 0.05
|
176 |
+
new_w = int(w * (self.crop_fraction - shrink))
|
177 |
+
new_h = int(new_w*new_h1/new_w1)
|
178 |
+
if new_h < h:
|
179 |
+
continue
|
180 |
+
|
181 |
+
if new_h >= h:
|
182 |
+
return image, target
|
183 |
+
|
184 |
+
boxes = target["boxes"]
|
185 |
+
if 'keypoints' in target:
|
186 |
+
keypoints = target["keypoints"]
|
187 |
+
else:
|
188 |
+
keypoints = []
|
189 |
+
for i in range(len(boxes)):
|
190 |
+
keypoints.append(torch.zeros((2,3)))
|
191 |
+
|
192 |
+
|
193 |
+
# Attempt to find a suitable crop region
|
194 |
+
success = False
|
195 |
+
for _ in range(100): # Max 100 attempts to find a valid crop
|
196 |
+
top = random.randint(0, h - new_h)
|
197 |
+
left = random.randint(0, w - new_w)
|
198 |
+
crop_region = [left, top, left + new_w, top + new_h]
|
199 |
+
|
200 |
+
# Check how many objects are fully contained in this region
|
201 |
+
contained_boxes = []
|
202 |
+
contained_keypoints = []
|
203 |
+
for box, kp in zip(boxes, keypoints):
|
204 |
+
if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
|
205 |
+
# Adjust box and keypoints coordinates
|
206 |
+
new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
|
207 |
+
new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
|
208 |
+
contained_boxes.append(new_box)
|
209 |
+
contained_keypoints.append(new_kp)
|
210 |
+
|
211 |
+
if len(contained_boxes) >= self.min_objects:
|
212 |
+
success = True
|
213 |
+
break
|
214 |
+
|
215 |
+
if success:
|
216 |
+
# Perform the actual crop
|
217 |
+
image = F.crop(image, top, left, new_h, new_w)
|
218 |
+
target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
|
219 |
+
if 'keypoints' in target:
|
220 |
+
target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 4))
|
221 |
+
|
222 |
+
return image, target
|
223 |
+
|
224 |
+
|
225 |
+
class RandomFlip:
|
226 |
+
def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
|
227 |
+
"""
|
228 |
+
Initializes the RandomFlip with probabilities for flipping.
|
229 |
+
|
230 |
+
Parameters:
|
231 |
+
- h_flip_prob (float): Probability of applying a horizontal flip to the image.
|
232 |
+
- v_flip_prob (float): Probability of applying a vertical flip to the image.
|
233 |
+
"""
|
234 |
+
self.h_flip_prob = h_flip_prob
|
235 |
+
self.v_flip_prob = v_flip_prob
|
236 |
+
|
237 |
+
def __call__(self, image, target):
|
238 |
+
"""
|
239 |
+
Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
|
240 |
+
|
241 |
+
Parameters:
|
242 |
+
- image (PIL Image): The image to be flipped.
|
243 |
+
- target (dict): The target dictionary containing 'boxes' and 'keypoints'.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
- PIL Image, dict: The flipped image and its updated target dictionary.
|
247 |
+
"""
|
248 |
+
if random.random() < self.h_flip_prob:
|
249 |
+
image = F.hflip(image)
|
250 |
+
w, _ = image.size # Get the new width of the image after flip for bounding box adjustment
|
251 |
+
# Adjust bounding boxes for horizontal flip
|
252 |
+
for i, box in enumerate(target['boxes']):
|
253 |
+
xmin, ymin, xmax, ymax = box
|
254 |
+
target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)
|
255 |
+
|
256 |
+
# Adjust keypoints for horizontal flip
|
257 |
+
if 'keypoints' in target:
|
258 |
+
new_keypoints = []
|
259 |
+
for keypoints_for_object in target['keypoints']:
|
260 |
+
flipped_keypoints_for_object = []
|
261 |
+
for kp in keypoints_for_object:
|
262 |
+
x, y = kp[:2]
|
263 |
+
new_x = w - x
|
264 |
+
flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
|
265 |
+
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
266 |
+
target['keypoints'] = torch.stack(new_keypoints)
|
267 |
+
|
268 |
+
if random.random() < self.v_flip_prob:
|
269 |
+
image = F.vflip(image)
|
270 |
+
_, h = image.size # Get the new height of the image after flip for bounding box adjustment
|
271 |
+
# Adjust bounding boxes for vertical flip
|
272 |
+
for i, box in enumerate(target['boxes']):
|
273 |
+
xmin, ymin, xmax, ymax = box
|
274 |
+
target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)
|
275 |
+
|
276 |
+
# Adjust keypoints for vertical flip
|
277 |
+
if 'keypoints' in target:
|
278 |
+
new_keypoints = []
|
279 |
+
for keypoints_for_object in target['keypoints']:
|
280 |
+
flipped_keypoints_for_object = []
|
281 |
+
for kp in keypoints_for_object:
|
282 |
+
x, y = kp[:2]
|
283 |
+
new_y = h - y
|
284 |
+
flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
|
285 |
+
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
286 |
+
target['keypoints'] = torch.stack(new_keypoints)
|
287 |
+
|
288 |
+
return image, target
|
289 |
+
|
290 |
+
|
291 |
+
class RandomRotate:
|
292 |
+
def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
|
293 |
+
"""
|
294 |
+
Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
|
295 |
+
|
296 |
+
Parameters:
|
297 |
+
- max_rotate_deg (int): Maximum degree to rotate the image.
|
298 |
+
- rotate_proba (float): Probability of applying rotation to the image.
|
299 |
+
"""
|
300 |
+
self.max_rotate_deg = max_rotate_deg
|
301 |
+
self.rotate_proba = rotate_proba
|
302 |
+
|
303 |
+
def __call__(self, image, target):
|
304 |
+
"""
|
305 |
+
Randomly rotates the image and updates the target data accordingly.
|
306 |
+
|
307 |
+
Parameters:
|
308 |
+
- image (PIL Image): The image to be rotated.
|
309 |
+
- target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
- PIL Image, dict: The rotated image and its updated target dictionary.
|
313 |
+
"""
|
314 |
+
if random.random() < self.rotate_proba:
|
315 |
+
angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
|
316 |
+
image = F.rotate(image, angle, expand=False, fill=200)
|
317 |
+
|
318 |
+
# Rotate bounding boxes
|
319 |
+
w, h = image.size
|
320 |
+
cx, cy = w / 2, h / 2
|
321 |
+
boxes = target["boxes"]
|
322 |
+
new_boxes = []
|
323 |
+
for box in boxes:
|
324 |
+
new_box = self.rotate_box(box, angle, cx, cy)
|
325 |
+
new_boxes.append(new_box)
|
326 |
+
target["boxes"] = torch.stack(new_boxes)
|
327 |
+
|
328 |
+
# Rotate keypoints
|
329 |
+
if 'keypoints' in target:
|
330 |
+
new_keypoints = []
|
331 |
+
for keypoints in target["keypoints"]:
|
332 |
+
new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
|
333 |
+
new_keypoints.append(new_kp)
|
334 |
+
target["keypoints"] = torch.stack(new_keypoints)
|
335 |
+
|
336 |
+
return image, target
|
337 |
+
|
338 |
+
def rotate_box(self, box, angle, cx, cy):
|
339 |
+
"""
|
340 |
+
Rotates a bounding box by a given angle around the center of the image.
|
341 |
+
"""
|
342 |
+
x1, y1, x2, y2 = box
|
343 |
+
corners = torch.tensor([
|
344 |
+
[x1, y1],
|
345 |
+
[x2, y1],
|
346 |
+
[x2, y2],
|
347 |
+
[x1, y2]
|
348 |
+
])
|
349 |
+
corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
|
350 |
+
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
351 |
+
corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
|
352 |
+
x_ = corners[:, 0]
|
353 |
+
y_ = corners[:, 1]
|
354 |
+
x_min, x_max = torch.min(x_), torch.max(x_)
|
355 |
+
y_min, y_max = torch.min(y_), torch.max(y_)
|
356 |
+
return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)
|
357 |
+
|
358 |
+
def rotate_keypoints(self, keypoints, angle, cx, cy):
|
359 |
+
"""
|
360 |
+
Rotates keypoints by a given angle around the center of the image.
|
361 |
+
"""
|
362 |
+
new_keypoints = []
|
363 |
+
for kp in keypoints:
|
364 |
+
x, y, v = kp
|
365 |
+
point = torch.tensor([x, y, 1])
|
366 |
+
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
367 |
+
new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
|
368 |
+
new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
|
369 |
+
return torch.stack(new_keypoints)
|
370 |
+
|
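The three augmentation transforms above all act on the (image, target) pair jointly; a short sketch of chaining them by hand on one sample (a PIL image and its target dict are assumed):
crop = RandomCrop(new_size=(1333, 800), crop_fraction=0.7, min_objects=3)
flip = RandomFlip(h_flip_prob=0.3, v_flip_prob=0.3)
rotate = RandomRotate(max_rotate_deg=20, rotate_proba=0.3)
image, target = crop(image, target)
image, target = flip(image, target)
image, target = rotate(image, target)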
371 |
+
def rotate_90_box(box, angle, w, h):
|
372 |
+
x1, y1, x2, y2 = box
|
373 |
+
if angle == 90:
|
374 |
+
return torch.tensor([y1,h-x2,y2,h-x1])
|
375 |
+
elif angle == 270 or angle == -90:
|
376 |
+
return torch.tensor([w-y2,x1,w-y1,x2])
|
377 |
+
else:
|
378 |
+
print("angle not supported")
|
379 |
+
|
380 |
+
def rotate_90_keypoints(kp, angle, w, h):
|
381 |
+
# Extract coordinates and visibility from each keypoint tensor
|
382 |
+
x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
|
383 |
+
x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
|
384 |
+
# Swap x and y coordinates for each keypoint
|
385 |
+
if angle == 90:
|
386 |
+
new = [[y1, h-x1, v1], [y2, h-x2, v2]]
|
387 |
+
elif angle == 270 or angle == -90:
|
388 |
+
new = [[w-y1, x1, v1], [w-y2, x2, v2]]
|
389 |
+
|
390 |
+
return torch.tensor(new, dtype=torch.float32)
|
391 |
+
|
392 |
+
|
393 |
+
def rotate_vertical(image, target):
|
394 |
+
# Rotate the image and target if the image is vertical
|
395 |
+
new_boxes = []
|
396 |
+
angle = random.choice([-90,90])
|
397 |
+
image = F.rotate(image, angle, expand=True, fill=200)
|
398 |
+
for box in target["boxes"]:
|
399 |
+
new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
|
400 |
+
new_boxes.append(new_box)
|
401 |
+
target["boxes"] = torch.stack(new_boxes)
|
402 |
+
|
403 |
+
if 'keypoints' in target:
|
404 |
+
new_kp = []
|
405 |
+
for kp in target['keypoints']:
|
406 |
+
new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
|
407 |
+
new_kp.append(new_key)
|
408 |
+
target['keypoints'] = torch.stack(new_kp)
|
409 |
+
return image, target
|
410 |
+
|
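For instance, rotating a box by 90 degrees inside a 100x200 image swaps and mirrors its coordinates:
box = torch.tensor([10., 20., 30., 60.])
rotate_90_box(box, 90, 100, 200)   # tensor([ 20., 170.,  60., 190.])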
411 |
+
class BPMN_Dataset(Dataset):
|
412 |
+
def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2, flip_transform=None, rotate_transform=None, new_size=(1333,800),keep_ratio=False,resize=True, model_type='object', rotate_vertical=False):
|
413 |
+
self.annotations = annotations
|
414 |
+
print(f"Loaded {len(self.annotations)} annotations.")
|
415 |
+
self.transform = transform
|
416 |
+
self.crop_transform = crop_transform
|
417 |
+
self.crop_prob = crop_prob
|
418 |
+
self.flip_transform = flip_transform
|
419 |
+
self.rotate_transform = rotate_transform
|
420 |
+
self.resize = resize
|
421 |
+
self.rotate_vertical = rotate_vertical
|
422 |
+
self.new_size = new_size
|
423 |
+
self.keep_ratio = keep_ratio
|
424 |
+
self.model_type = model_type
|
425 |
+
if model_type == 'object':
|
426 |
+
self.dict = object_dict
|
427 |
+
elif model_type == 'arrow':
|
428 |
+
self.dict = arrow_dict
|
429 |
+
self.rotate_90_proba = rotate_90_proba
|
430 |
+
|
431 |
+
def __len__(self):
|
432 |
+
return len(self.annotations)
|
433 |
+
|
434 |
+
def __getitem__(self, idx):
|
435 |
+
annotation = self.annotations[idx]
|
436 |
+
image = annotation.img.convert("RGB")
|
437 |
+
boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
|
438 |
+
labels_names = [ann for ann in annotation.categories]
|
439 |
+
|
440 |
+
#only keep the labels, boxes and keypoints that are in the class_dict
|
441 |
+
kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
|
442 |
+
boxes = boxes[kept_indices]
|
443 |
+
labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
|
444 |
+
|
445 |
+
labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)
|
446 |
+
|
447 |
+
# Initialize keypoints tensor
|
448 |
+
max_keypoints = 2
|
449 |
+
keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
|
450 |
+
|
451 |
+
ii=0
|
452 |
+
for i, ann in enumerate(annotation.annotations):
|
453 |
+
#only keep the keypoints that are in the kept indices
|
454 |
+
if i not in kept_indices:
|
455 |
+
continue
|
456 |
+
if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
|
457 |
+
# Fill the keypoints tensor for this annotation, mark as visible (1)
|
458 |
+
kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
|
459 |
+
kp = kp[:,:2]
|
460 |
+
visible = np.ones((kp.shape[0], 1), dtype=np.float32)
|
461 |
+
kp = np.hstack([kp, visible])
|
462 |
+
keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
|
463 |
+
ii += 1
|
464 |
+
|
465 |
+
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
466 |
+
|
467 |
+
if self.model_type == 'object':
|
468 |
+
target = {
|
469 |
+
"boxes": boxes,
|
470 |
+
"labels": labels_id,
|
471 |
+
#"area": area,
|
472 |
+
#"keypoints": keypoints,
|
473 |
+
}
|
474 |
+
elif self.model_type == 'arrow':
|
475 |
+
target = {
|
476 |
+
"boxes": boxes,
|
477 |
+
"labels": labels_id,
|
478 |
+
#"area": area,
|
479 |
+
"keypoints": keypoints,
|
480 |
+
}
|
481 |
+
|
482 |
+
# Randomly apply flip transform
|
483 |
+
if self.flip_transform:
|
484 |
+
image, target = self.flip_transform(image, target)
|
485 |
+
|
486 |
+
# Randomly apply rotate transform
|
487 |
+
if self.rotate_transform:
|
488 |
+
image, target = self.rotate_transform(image, target)
|
489 |
+
|
490 |
+
# Randomly apply the custom cropping transform
|
491 |
+
if self.crop_transform and random.random() < self.crop_prob:
|
492 |
+
image, target = self.crop_transform(image, target)
|
493 |
+
|
494 |
+
# Rotate vertical image
|
495 |
+
if self.rotate_vertical and random.random() < self.rotate_90_proba:
|
496 |
+
image, target = rotate_vertical(image, target)
|
497 |
+
|
498 |
+
if self.resize:
|
499 |
+
if self.keep_ratio:
|
500 |
+
original_size = image.size
|
501 |
+
# Calculate scale to fit the new size while maintaining aspect ratio
|
502 |
+
scale = min(self.new_size[0] / original_size[0], self.new_size[1] / original_size[1])
|
503 |
+
new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
|
504 |
+
|
505 |
+
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), (new_scaled_size))
|
506 |
+
if 'area' in target:
|
507 |
+
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
508 |
+
|
509 |
+
if 'keypoints' in target:
|
510 |
+
for i in range(len(target['keypoints'])):
|
511 |
+
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), (new_scaled_size))
|
512 |
+
|
513 |
+
# Resize image to new scaled size
|
514 |
+
image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
|
515 |
+
|
516 |
+
# Pad the resized image to make it exactly the desired size
|
517 |
+
padding = [0, 0, self.new_size[0] - new_scaled_size[0], self.new_size[1] - new_scaled_size[1]]
|
518 |
+
image = F.pad(image, padding, fill=200, padding_mode='constant')
|
519 |
+
else:
|
520 |
+
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
|
521 |
+
if 'area' in target:
|
522 |
+
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
523 |
+
if 'keypoints' in target:
|
524 |
+
for i in range(len(target['keypoints'])):
|
525 |
+
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
|
526 |
+
image = F.resize(image, (self.new_size[1], self.new_size[0]))
|
527 |
+
|
528 |
+
return self.transform(image), target
|
529 |
+
|
530 |
+
def collate_fn(batch):
|
531 |
+
"""
|
532 |
+
Custom collation function for DataLoader that handles batches of images and targets.
|
533 |
+
|
534 |
+
This function ensures that images are properly batched together using PyTorch's default collation,
|
535 |
+
while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
|
536 |
+
as each image might have a different number of objects detected.
|
537 |
+
|
538 |
+
Parameters:
|
539 |
+
- batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
- Tuple containing:
|
543 |
+
- Tensor: Batched images.
|
544 |
+
- List of dicts: Targets corresponding to each image in the batch.
|
545 |
+
"""
|
546 |
+
images, targets = zip(*batch) # Unzip the batch into separate lists for images and targets.
|
547 |
+
|
548 |
+
# Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
|
549 |
+
images = default_collate(images)
|
550 |
+
|
551 |
+
return images, targets
|
552 |
+
|
553 |
+
|
554 |
+
|
555 |
+
def create_loader(new_size,transformation, annotations1, annotations2=None,
|
556 |
+
batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
|
557 |
+
h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
|
558 |
+
seed=42, resize=True, rotate_vertical=False, keep_ratio=False, model_type = 'object'):
|
559 |
+
"""
|
560 |
+
Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
|
561 |
+
|
562 |
+
Parameters:
|
563 |
+
- transformation (callable): Transformation function to apply to each image (e.g., normalization).
|
564 |
+
- annotations1 (list): Primary list of annotations.
|
565 |
+
- annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
|
566 |
+
- batch_size (int): Number of images per batch.
|
567 |
+
- crop_prob (float): Probability of applying the crop transformation.
|
568 |
+
- crop_fraction (float): Fraction of the original width to use when cropping.
|
569 |
+
- min_objects (int): Minimum number of objects required to be within the crop.
|
570 |
+
- h_flip_prob (float): Probability of applying horizontal flip.
|
571 |
+
- v_flip_prob (float): Probability of applying vertical flip.
|
572 |
+
- seed (int): Seed for random number generators for reproducibility.
|
573 |
+
- resize (bool): Flag indicating whether to resize images after transformations.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
- DataLoader: Configured data loader for the dataset.
|
577 |
+
"""
|
578 |
+
|
579 |
+
# Initialize custom transformations for cropping and flipping
|
580 |
+
custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
|
581 |
+
custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
|
582 |
+
custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
|
583 |
+
|
584 |
+
# Create the primary dataset
|
585 |
+
dataset = BPMN_Dataset(
|
586 |
+
annotations=annotations1,
|
587 |
+
transform=transformation,
|
588 |
+
crop_transform=custom_crop_transform,
|
589 |
+
crop_prob=crop_prob,
|
590 |
+
rotate_90_proba=rotate_90_proba,
|
591 |
+
flip_transform=custom_flip_transform,
|
592 |
+
rotate_transform=custom_rotate_transform,
|
593 |
+
rotate_vertical=rotate_vertical,
|
594 |
+
new_size=new_size,
|
595 |
+
keep_ratio=keep_ratio,
|
596 |
+
model_type=model_type,
|
597 |
+
resize=resize
|
598 |
+
)
|
599 |
+
|
600 |
+
# Optionally concatenate a second dataset
|
601 |
+
if annotations2:
|
602 |
+
dataset2 = BPMN_Dataset(
|
603 |
+
annotations=annotations2,
|
604 |
+
transform=transformation,
|
605 |
+
crop_transform=custom_crop_transform,
|
606 |
+
crop_prob=crop_prob,
|
607 |
+
rotate_90_proba=rotate_90_proba,
|
608 |
+
flip_transform=custom_flip_transform,
|
609 |
+
rotate_vertical=rotate_vertical,
|
610 |
+
new_size=new_size,
|
611 |
+
keep_ratio=keep_ratio,
|
612 |
+
model_type=model_type,
|
613 |
+
resize=resize
|
614 |
+
)
|
615 |
+
dataset = ConcatDataset([dataset, dataset2]) # Concatenate the two datasets
|
616 |
+
|
617 |
+
# Set the seed for reproducibility in random operations within transformations and data loading
|
618 |
+
random.seed(seed)
|
619 |
+
torch.manual_seed(seed)
|
620 |
+
|
621 |
+
# Create the DataLoader with the dataset
|
622 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
|
623 |
+
|
624 |
+
return data_loader
|
625 |
+
|
626 |
+
|
627 |
+
|
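A minimal sketch of building a training loader with create_loader; the annotation list and the normalisation transform are assumed to exist elsewhere in the project:
transformation = transforms.Compose([transforms.ToTensor()])
train_loader = create_loader((1333, 800), transformation, train_annotations,
                             batch_size=4, crop_prob=0.2, h_flip_prob=0.3,
                             v_flip_prob=0.3, rotate_proba=0.3, model_type='object')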
628 |
+
def write_results(name_model,metrics_list,start_epoch):
|
629 |
+
with open('./results/'+ name_model+ '.txt', 'w') as f:
|
630 |
+
for i in range(len(metrics_list[0])):
|
631 |
+
f.write(f"{i+1+start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n")
|
632 |
+
|
633 |
+
|
634 |
+
def find_other_keypoint(idx, keypoints, boxes):
|
635 |
+
box = boxes[idx]
|
636 |
+
key1,key2 = keypoints[idx]
|
637 |
+
x1, y1, x2, y2 = box
|
638 |
+
center = ((x1 + x2) // 2, (y1 + y2) // 2)
|
639 |
+
average_keypoint = (key1 + key2) // 2
|
640 |
+
#find the opposite keypoint to the center
|
641 |
+
if average_keypoint[0] < center[0]:
|
642 |
+
x = center[0] + abs(center[0] - average_keypoint[0])
|
643 |
+
else:
|
644 |
+
x = center[0] - abs(center[0] - average_keypoint[0])
|
645 |
+
if average_keypoint[1] < center[1]:
|
646 |
+
y = center[1] + abs(center[1] - average_keypoint[1])
|
647 |
+
else:
|
648 |
+
y = center[1] - abs(center[1] - average_keypoint[1])
|
649 |
+
return x, y, average_keypoint[0], average_keypoint[1]
|
650 |
+
|
651 |
+
|
652 |
+
def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
|
653 |
+
"""
|
654 |
+
Filters overlapping boxes based on the Intersection over Union (IoU) metric, keeping only the boxes with the highest scores.
|
655 |
+
|
656 |
+
Parameters:
|
657 |
+
- boxes (np.ndarray): Array of bounding boxes with shape (N, 4), where each row contains [x_min, y_min, x_max, y_max].
|
658 |
+
- scores (np.ndarray): Array of scores for each box, reflecting the confidence of detection.
|
659 |
+
- labels (np.ndarray): Array of labels corresponding to each box.
|
660 |
+
- keypoints (np.ndarray): Array of keypoints associated with each box.
|
661 |
+
- iou_threshold (float): Threshold for IoU above which a box is considered overlapping.
|
662 |
+
|
663 |
+
Returns:
|
664 |
+
- tuple: Filtered boxes, scores, labels, and keypoints.
|
665 |
+
"""
|
666 |
+
# Calculate the area of each bounding box to use in IoU calculation.
|
667 |
+
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
668 |
+
|
669 |
+
# Sort the indices of the boxes based on their scores in descending order.
|
670 |
+
order = scores.argsort()[::-1]
|
671 |
+
|
672 |
+
keep = [] # List to store indices of boxes to keep.
|
673 |
+
|
674 |
+
while order.size > 0:
|
675 |
+
# Take the first index (highest score) from the sorted list.
|
676 |
+
i = order[0]
|
677 |
+
keep.append(i) # Add this index to 'keep' list.
|
678 |
+
|
679 |
+
# Compute the coordinates of the intersection rectangle.
|
680 |
+
xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
|
681 |
+
yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
|
682 |
+
xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
|
683 |
+
yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
|
684 |
+
|
685 |
+
# Compute the area of the intersection rectangle.
|
686 |
+
w = np.maximum(0.0, xx2 - xx1)
|
687 |
+
h = np.maximum(0.0, yy2 - yy1)
|
688 |
+
inter = w * h
|
689 |
+
|
690 |
+
# Calculate IoU and find boxes with IoU less than the threshold to keep.
|
691 |
+
iou = inter / (areas[i] + areas[order[1:]] - inter)
|
692 |
+
inds = np.where(iou <= iou_threshold)[0]
|
693 |
+
|
694 |
+
# Update the list of box indices to consider in the next iteration.
|
695 |
+
order = order[inds + 1] # Skip the first element since it's already included in 'keep'.
|
696 |
+
|
697 |
+
# Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return.
|
698 |
+
boxes = boxes[keep]
|
699 |
+
scores = scores[keep]
|
700 |
+
labels = labels[keep]
|
701 |
+
keypoints = keypoints[keep]
|
702 |
+
|
703 |
+
return boxes, scores, labels, keypoints
|
704 |
+
|
705 |
+
|
706 |
+
|
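Applied to raw detector output, the filter above expects numpy arrays; a hedged sketch (the prediction dict is assumed to come from the detection model):
boxes = prediction['boxes'].cpu().numpy()
scores = prediction['scores'].cpu().numpy()
labels = prediction['labels'].cpu().numpy()
keypoints = prediction['keypoints'].cpu().numpy()
boxes, scores, labels, keypoints = filter_overlap_boxes(
    boxes, scores, labels, keypoints, iou_threshold=0.5)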
707 |
+
def draw_annotations(image,
|
708 |
+
target=None,
|
709 |
+
prediction=None,
|
710 |
+
full_prediction=None,
|
711 |
+
text_predictions=None,
|
712 |
+
model_dict=class_dict,
|
713 |
+
draw_keypoints=False,
|
714 |
+
draw_boxes=False,
|
715 |
+
draw_text=False,
|
716 |
+
draw_links=False,
|
717 |
+
draw_twins=False,
|
718 |
+
write_class=False,
|
719 |
+
write_score=False,
|
720 |
+
write_text=False,
|
721 |
+
write_idx=False,
|
722 |
+
score_threshold=0.4,
|
723 |
+
keypoints_correction=False,
|
724 |
+
only_print=None,
|
725 |
+
axis=False,
|
726 |
+
return_image=False,
|
727 |
+
new_size=(1333,800),
|
728 |
+
resize=False):
|
729 |
+
"""
|
730 |
+
Draws annotations on images including bounding boxes, keypoints, links, and text.
|
731 |
+
|
732 |
+
Parameters:
|
733 |
+
- image (np.array): The image on which annotations will be drawn.
|
734 |
+
- target (dict): Ground truth data containing boxes, labels, etc.
|
735 |
+
- prediction (dict): Prediction data from a model.
|
736 |
+
- full_prediction (dict): Additional detailed prediction data, potentially including relationships.
|
737 |
+
- text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
|
738 |
+
- model_dict (dict): Mapping from class IDs to class names.
|
739 |
+
- draw_keypoints (bool): Flag to draw keypoints.
|
740 |
+
- draw_boxes (bool): Flag to draw bounding boxes.
|
741 |
+
- draw_text (bool): Flag to draw text annotations.
|
742 |
+
- draw_links (bool): Flag to draw links between annotations.
|
743 |
+
- draw_twins (bool): Flag to draw twins keypoints.
|
744 |
+
- write_class (bool): Flag to write class names near the annotations.
|
745 |
+
- write_score (bool): Flag to write scores near the annotations.
|
746 |
+
- write_text (bool): Flag to write OCR recognized text.
|
747 |
+
- score_threshold (float): Threshold for scores above which annotations will be drawn.
|
748 |
+
- only_print (str): Specific class name to filter annotations by.
|
749 |
+
- resize (bool): Whether to resize annotations to fit the image size.
|
750 |
+
"""
|
751 |
+
|
752 |
+
# Convert image to RGB (if not already in that format)
|
753 |
+
if prediction is None:
|
754 |
+
image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
755 |
+
|
756 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
757 |
+
image_copy = image.copy()
|
758 |
+
scale = max(image.shape[0], image.shape[1]) / 1000
|
759 |
+
|
760 |
+
# Function to draw bounding boxes and keypoints
|
761 |
+
def draw(data,is_prediction=False):
|
762 |
+
""" Helper function to draw annotations based on provided data. """
|
763 |
+
|
764 |
+
for i in range(len(data['boxes'])):
|
765 |
+
if is_prediction:
|
766 |
+
box = data['boxes'][i].tolist()
|
767 |
+
x1, y1, x2, y2 = box
|
768 |
+
if resize:
|
769 |
+
x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
|
770 |
+
score = data['scores'][i].item()
|
771 |
+
if score < score_threshold:
|
772 |
+
continue
|
773 |
+
else:
|
774 |
+
box = data['boxes'][i].tolist()
|
775 |
+
x1, y1, x2, y2 = box
|
776 |
+
if draw_boxes:
|
777 |
+
if only_print is not None:
|
778 |
+
if data['labels'][i] != list(model_dict.values()).index(only_print):
|
779 |
+
continue
|
780 |
+
cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))
|
781 |
+
if is_prediction and write_score:
|
782 |
+
cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)
|
783 |
+
|
784 |
+
if write_class and 'labels' in data:
|
785 |
+
class_id = data['labels'][i].item()
|
786 |
+
cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
|
787 |
+
|
788 |
+
if write_idx:
|
789 |
+
cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)
|
790 |
+
|
791 |
+
|
792 |
+
# Draw keypoints if available
|
793 |
+
if draw_keypoints and 'keypoints' in data:
|
794 |
+
if is_prediction and keypoints_correction:
|
795 |
+
for idx, (key1, key2) in enumerate(data['keypoints']):
|
796 |
+
if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
|
797 |
+
list(model_dict.values()).index('messageFlow'),
|
798 |
+
list(model_dict.values()).index('dataAssociation')]:
|
799 |
+
continue
|
800 |
+
# Calculate the Euclidean distance between the two keypoints
|
801 |
+
distance = np.linalg.norm(key1[:2] - key2[:2])
|
802 |
+
|
803 |
+
if distance < 5:
|
804 |
+
x_new,y_new, x,y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
|
805 |
+
data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1])
|
806 |
+
data['keypoints'][idx][1] = torch.tensor([x, y,1])
|
807 |
+
print("keypoint has been changed")
|
808 |
+
for i in range(len(data['keypoints'])):
|
809 |
+
kp = data['keypoints'][i]
|
810 |
+
for j in range(kp.shape[0]):
|
811 |
+
if is_prediction and data['labels'][i] != list(model_dict.values()).index('sequenceFlow') and data['labels'][i] != list(model_dict.values()).index('messageFlow') and data['labels'][i] != list(model_dict.values()).index('dataAssociation'):
|
812 |
+
continue
|
813 |
+
if is_prediction:
|
814 |
+
score = data['scores'][i]
|
815 |
+
if score < score_threshold:
|
816 |
+
continue
|
817 |
+
x,y,v = np.array(kp[j])
|
818 |
+
if resize:
|
819 |
+
x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
|
820 |
+
if j == 0:
|
821 |
+
cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
|
822 |
+
else:
|
823 |
+
cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
|
824 |
+
|
825 |
+
# Draw text predictions if available
|
826 |
+
if (draw_text or write_text) and text_predictions is not None:
|
827 |
+
for i in range(len(text_predictions[0])):
|
828 |
+
x1, y1, x2, y2 = text_predictions[0][i]
|
829 |
+
text = text_predictions[1][i]
|
830 |
+
if resize:
|
831 |
+
x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
|
832 |
+
if draw_text:
|
833 |
+
cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
|
834 |
+
if write_text:
|
835 |
+
cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2)
|
836 |
+
|
837 |
+
def draw_with_links(full_prediction):
|
838 |
+
'''Draws links between objects based on the full prediction data.'''
|
839 |
+
#check if keypoints detected are the same
|
840 |
+
if draw_twins and full_prediction is not None:
|
841 |
+
# Pre-calculate indices for performance
|
842 |
+
circle_color = (0, 255, 0) # Green color for the circle
|
843 |
+
circle_radius = int(10 * scale) # Circle radius scaled by image scale
|
844 |
+
|
845 |
+
for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
|
846 |
+
if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
|
847 |
+
list(model_dict.values()).index('messageFlow'),
|
848 |
+
list(model_dict.values()).index('dataAssociation')]:
|
849 |
+
continue
|
850 |
+
# Calculate the Euclidean distance between the two keypoints
|
851 |
+
distance = np.linalg.norm(key1[:2] - key2[:2])
|
852 |
+
if distance < 10:
|
853 |
+
x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
|
854 |
+
cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
|
855 |
+
cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)
|
856 |
+
|
857 |
+
# Draw links between objects
|
858 |
+
if draw_links==True and full_prediction is not None:
|
859 |
+
for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
|
860 |
+
if start_idx is None or end_idx is None:
|
861 |
+
continue
|
862 |
+
start_box = full_prediction['boxes'][start_idx]
|
863 |
+
end_box = full_prediction['boxes'][end_idx]
|
864 |
+
current_box = full_prediction['boxes'][i]
|
865 |
+
# Calculate the center of each bounding box
|
866 |
+
start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
|
867 |
+
end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
|
868 |
+
current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
|
869 |
+
# Draw a line between the centers of the connected objects
|
870 |
+
cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
|
871 |
+
cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
|
872 |
+
|
873 |
+
i+=1
|
874 |
+
|
875 |
+
# Draw GT annotations
|
876 |
+
if target is not None:
|
877 |
+
draw(target, is_prediction=False)
|
878 |
+
# Draw predictions
|
879 |
+
if prediction is not None:
|
880 |
+
#prediction = prediction[0]
|
881 |
+
draw(prediction, is_prediction=True)
|
882 |
+
# Draw links with full predictions
|
883 |
+
if full_prediction is not None:
|
884 |
+
draw_with_links(full_prediction)
|
885 |
+
|
886 |
+
# Display the image
|
887 |
+
image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
|
888 |
+
plt.figure(figsize=(12, 12))
|
889 |
+
plt.imshow(image_copy)
|
890 |
+
if axis==False:
|
891 |
+
plt.axis('off')
|
892 |
+
plt.show()
|
893 |
+
|
894 |
+
if return_image:
|
895 |
+
return image_copy
|
896 |
+
|
897 |
+
def find_closest_object(keypoint, boxes, labels):
|
898 |
+
"""
|
899 |
+
Find the closest object to a keypoint based on their proximity.
|
900 |
+
|
901 |
+
Parameters:
|
902 |
+
- keypoint (numpy.ndarray): The coordinates of the keypoint.
|
903 |
+
- boxes (numpy.ndarray): The bounding boxes of the objects.
|
904 |
+
|
905 |
+
Returns:
|
906 |
+
- int or None: The index of the closest object to the keypoint, or None if no object is found.
|
907 |
+
"""
|
908 |
+
min_distance = float('inf')
|
909 |
+
closest_object_idx = None
|
910 |
+
# Iterate over each bounding box
|
911 |
+
for i, box in enumerate(boxes):
|
912 |
+
if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
|
913 |
+
list(class_dict.values()).index('messageFlow'),
|
914 |
+
list(class_dict.values()).index('dataAssociation'),
|
915 |
+
#list(class_dict.values()).index('pool'),
|
916 |
+
list(class_dict.values()).index('lane')]:
|
917 |
+
continue
|
918 |
+
x1, y1, x2, y2 = box
|
919 |
+
|
920 |
+
top = ((x1+x2)/2, y1)
|
921 |
+
bottom = ((x1+x2)/2, y2)
|
922 |
+
left = (x1, (y1+y2)/2)
|
923 |
+
right = (x2, (y1+y2)/2)
|
924 |
+
points = [left, top , right, bottom]
|
925 |
+
|
926 |
+
# Calculate the distance between the keypoint and the center of the bounding box
|
927 |
+
for point in points:
|
928 |
+
distance = np.linalg.norm(keypoint[:2] - point)
|
929 |
+
# Update the closest object index if this object is closer
|
930 |
+
if distance < min_distance:
|
931 |
+
min_distance = distance
|
932 |
+
closest_object_idx = i
|
933 |
+
best_point = point
|
934 |
+
|
935 |
+
return closest_object_idx, best_point
|
936 |
+
|
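A rough sketch of how the helper above snaps an arrow endpoint to the nearest shape border point (the prediction dict is assumed):
keypoint = np.array([250., 120., 1.])
idx, anchor_point = find_closest_object(keypoint,
                                        prediction['boxes'],
                                        prediction['labels'])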