add commentary to all the code
- app.py +26 -11
- modules/OCR.py +143 -41
- modules/dataset_loader.py +166 -60
- modules/eval.py +391 -94
- modules/streamlit_utils.py +209 -48
- modules/toWizard.py +95 -14
- modules/toXML.py +330 -64
- modules/train.py +265 -240
- modules/utils.py +79 -196
app.py
CHANGED
@@ -6,75 +6,90 @@ import numpy as np
 from modules.streamlit_utils import *
 from modules.utils import error

-
 def main():
-
+    """
+    Main function to run the Streamlit application for BPMN AI model recognition.
+    """
+
+    # Check if the model is loaded in the session state
     if 'model_loaded' not in st.session_state:
         st.session_state.model_loaded = False

         st.session_state.first_run = True
+
+    # Configure the Streamlit page and retrieve screen details
     is_mobile, screen_width = configure_page()
+
+    # Display various UI components
     display_banner(is_mobile)
     display_title(is_mobile)
     display_sidebar()

+    # Initialize session state variables
     initialize_session_state()

     cropped_image = None

+    # Load example or user-uploaded image
     img_selected = load_example_image()
     uploaded_file = load_user_image(img_selected, is_mobile)

+    # Display the uploaded image and allow cropping
     if uploaded_file is not None:
         cropped_image = display_image(uploaded_file, screen_width, is_mobile)

+    # Set score threshold for prediction if an image is uploaded
     if uploaded_file is not None:
         get_score_threshold(is_mobile)

+    # Launch prediction when the button is clicked
     if st.button("🚀 Launch Prediction"):
         st.session_state.image = launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
         st.session_state.original_prediction = st.session_state.prediction.copy()
         st.rerun()

-    # Create placeholders for
+    # Create placeholders for different sections of the UI
     prediction_result_placeholder = st.empty()
     additional_options_placeholder = st.empty()
     modeler_placeholder = st.empty()

-
+    # Display prediction results and options if predictions are available
     if 'prediction' in st.session_state and uploaded_file:
         if st.session_state.image != cropped_image:
             print('Image has changed')
-            # Delete the prediction
+            # Delete the prediction if the image has changed
             del st.session_state.prediction
             return

-        if len(st.session_state.prediction['labels'])==0:
-            error("No prediction available. Please upload a BPMN image or decrease the detection score
+        if len(st.session_state.prediction['labels']) == 0:
+            error("No prediction available. Please upload a BPMN image or decrease the detection score threshold.")
         else:
             with prediction_result_placeholder.container():
                 if is_mobile:
-                    display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6*screen_width))
+                    display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
                 else:
                     with st.expander("Show result of prediction"):
-                        display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6*screen_width))
+                        display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))

+            # Provide additional options for modification if not on mobile
             if not is_mobile:
                 with additional_options_placeholder.container():
                     state = modify_results()

-
+            # Display BPMN modeler options and result
             with modeler_placeholder.container():
                 modeler_options(is_mobile)
                 display_bpmn_modeler(is_mobile, screen_width)
     else:
+        # Clear placeholders if no predictions are available
         prediction_result_placeholder.empty()
         additional_options_placeholder.empty()
         modeler_placeholder.empty()
-        # Create
+        # Create space for scrolling
         for _ in range(50):
             st.text("")

+    # Force garbage collection
     gc.collect()

 if __name__ == "__main__":
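The app.py changes above lean on Streamlit's session state surviving reruns. As a reference, here is a minimal standalone sketch of that pattern; the names are purely illustrative and not the app's own helpers:

import streamlit as st

# One-time setup guarded by a flag that persists across reruns
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False  # flipped to True once a model is actually loaded
    st.session_state.first_run = True

placeholder = st.empty()  # reserved slot that can later be filled or cleared

if st.button("Run"):
    st.session_state.result = "some prediction"
    st.rerun()  # rerun the script so the updated state is rendered

if 'result' in st.session_state:
    placeholder.write(st.session_state.result)
else:
    placeholder.empty()  # clear the slot when there is nothing to show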
modules/OCR.py
CHANGED
@@ -3,13 +3,14 @@ import os
 from azure.ai.vision.imageanalysis import ImageAnalysisClient
 from azure.ai.vision.imageanalysis.models import VisualFeatures
 from azure.core.credentials import AzureKeyCredential
-import time
 import numpy as np
 import networkx as nx
 from modules.utils import class_dict, proportion_inside
 import json
 from modules.utils import rescale_boxes as rescale, is_vertical
-import
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+import logging

 VISION_KEY = os.getenv("VISION_KEY")
 VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
@@ -20,15 +21,17 @@ VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")

 VISION_KEY = json_data["VISION_KEY"]
 VISION_ENDPOINT = json_data["VISION_ENDPOINT"]"""
-
-
-import logging
+
+

 # Suppress specific warnings from transformers
 logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

 # Function to initialize the model and tokenizer
 def initialize_model():
+    """
+    Initialize the tokenizer and model for sentiment analysis.
+    """
     tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
     model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
     return tokenizer, model
@@ -38,6 +41,17 @@ tokenizer, emotion_model = initialize_model()

 # Function to perform sentiment analysis and return the highest scoring emotion and its score between positive and negative
 def analyze_sentiment(sentence, tokenizer=tokenizer, model=emotion_model):
+    """
+    Analyze the sentiment of a given sentence using the initialized tokenizer and model.
+
+    Parameters:
+    - sentence (str): The input sentence to analyze.
+    - tokenizer (AutoTokenizer): The tokenizer for processing the sentence.
+    - model (AutoModelForSequenceClassification): The model for sentiment analysis.
+
+    Returns:
+    - tuple: The highest scoring emotion ('positive' or 'negative') and its corresponding score.
+    """
     inputs = tokenizer(sentence, return_tensors="pt")
     outputs = model(**inputs)
     probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
@@ -51,8 +65,16 @@ def analyze_sentiment(sentence, tokenizer=tokenizer, model=emotion_model):
     return highest_emotion, highest_score

 def sample_ocr_image_file(image_data):
-
-
+    """
+    Sample OCR function to analyze an image file and extract text using Azure's Computer Vision service.
+
+    Parameters:
+    - image_data (bytes): The image data in bytes.
+
+    Returns:
+    - result: The OCR result from the Computer Vision service.
+    """
+    # Set the values of your computer vision endpoint and computer vision key as environment variables:
     try:
         endpoint = VISION_ENDPOINT
         key = VISION_KEY
@@ -77,16 +99,35 @@ def sample_ocr_image_file(image_data):


 def text_prediction(image):
-
+    """
+    Perform OCR on an image to extract text.
+
+    Parameters:
+    - image: The image to process.
+
+    Returns:
+    - ocr_result: The OCR result.
+    """
+    # Transform the image into a byte array
     image.save('temp.jpg')
     with open('temp.jpg', 'rb') as f:
         image_data = f.read()
     ocr_result = sample_ocr_image_file(image_data)
-    #
+    # Delete the temporary image
     os.remove('temp.jpg')
     return ocr_result

 def filter_text(ocr_result, threshold=0.5):
+    """
+    Filter and process the OCR results to remove unwanted characters and low-confidence words.
+
+    Parameters:
+    - ocr_result: The OCR result.
+    - threshold (float): The confidence threshold for filtering words.
+
+    Returns:
+    - list_of_lines: Processed text lines and their bounding boxes.
+    """
     words_to_cancel = {"-","--","---","+",".",",","#","@","!","?","(",")","[","]","{","}","<",">","/","\\","|","-","_","=","&","^","%","$","£","€","¥","¢","¤","§","©","®","™","°","±","×","÷","¶","∆","∏","∑","∞","√","∫","≈","≠","≤","≥","≡","∼"}
     # Add every other one-letter word to the list of words to cancel, except 'I' and 'a'
     for letter in "bcdefghjklmnopqrstuvwxyz1234567890": # All lowercase letters except 'a'
@@ -132,10 +173,16 @@ def filter_text(ocr_result, threshold=0.5):
     return list_of_lines


-
-
+def get_box_points(box):
+    """
+    Returns all critical points of a box: corners and midpoints of edges.

+    Parameters:
+    - box (array): Bounding box coordinates [xmin, ymin, xmax, ymax].

+    Returns:
+    - numpy.array: Array of critical points.
+    """
     xmin, ymin, xmax, ymax = box
     return np.array([
         [xmin, ymin],  # Bottom-left corner
@@ -149,7 +196,16 @@ def get_box_points(box):
     ])

 def min_distance_between_boxes(box1, box2):
-    """
+    """
+    Computes the minimum distance between two boxes considering all critical points.
+
+    Parameters:
+    - box1 (array): First bounding box coordinates.
+    - box2 (array): Second bounding box coordinates.
+
+    Returns:
+    - float: The minimum distance between the two boxes.
+    """
     points1 = get_box_points(box1)
     points2 = get_box_points(box2)

@@ -162,7 +218,17 @@ def min_distance_between_boxes(box1, box2):
     return min_dist

 def are_close(box1, box2, threshold=50):
-    """
+    """
+    Determines if boxes are close based on their corners and center points.
+
+    Parameters:
+    - box1 (array): First bounding box coordinates.
+    - box2 (array): Second bounding box coordinates.
+    - threshold (int): Distance threshold for determining closeness.
+
+    Returns:
+    - bool: True if boxes are close, otherwise False.
+    """
     corners1 = np.array([
         [box1[0], box1[1]], [box1[0], box1[3]], [box1[2], box1[1]], [box1[2], box1[3]],
         [(box1[0]+box1[2])/2, box1[1]], [(box1[0]+box1[2])/2, box1[3]],
@@ -180,13 +246,25 @@ def are_close(box1, box2, threshold=50):
     return False

 def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
-    """
+    """
+    Find the closest box to the given text box within a specified threshold.
+
+    Parameters:
+    - text_box (array): The text box coordinates.
+    - all_boxes (list): List of all bounding boxes.
+    - labels (list): List of labels corresponding to the boxes.
+    - threshold (float): Distance threshold for determining closeness.
+    - iou_threshold (float): IoU threshold for determining if a text is inside a sequenceFlow.
+
+    Returns:
+    - int or None: Index of the closest box or None if no box is close enough.
+    """
     min_distance = float('inf')
     closest_index = None

-    #
+    # Check if the text is inside a sequenceFlow
     for j in range(len(all_boxes)):
-        if proportion_inside(text_box, all_boxes[j])>iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
+        if proportion_inside(text_box, all_boxes[j]) > iou_threshold and labels[j] == list(class_dict.values()).index('sequenceFlow'):
             return j

     for i, box in enumerate(all_boxes):
@@ -209,20 +287,32 @@ def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
     return None


-
 def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
-    """
+    """
+    Maps text boxes to task boxes and groups texts within each task based on proximity.
+
+    Parameters:
+    - task_boxes (list): List of task bounding boxes.
+    - text_boxes (list): List of text bounding boxes.
+    - texts (list): List of texts corresponding to the text boxes.
+    - min_dist (float): Minimum distance threshold for grouping.
+    - iou_threshold (float): IoU threshold for determining if text is inside a task box.
+    - percentage_thresh (float): Percentage threshold for determining if text boxes are close.
+
+    Returns:
+    - tuple: Grouped task-related texts, their bounding boxes, grouped information texts, and their bounding boxes.
+    """
     G = nx.Graph()

     # Map each text box to the nearest task box
     task_to_texts = {i: [] for i in range(len(task_boxes))}
-    information_texts = [] #
+    information_texts = []  # Texts not inside any task box
     text_to_task_mapped = [False] * len(text_boxes)

     for idx, text_box in enumerate(text_boxes):
         mapped = False
         for jdx, task_box in enumerate(task_boxes):
-            if proportion_inside(text_box, task_box)>iou_threshold:
+            if proportion_inside(text_box, task_box) > iou_threshold:
                 task_to_texts[jdx].append(idx)
                 text_to_task_mapped[idx] = True
                 mapped = True
@@ -326,32 +416,45 @@ def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, p
     return all_grouped_texts, sentence_boxes, information_grouped_texts, info_sentence_boxes


-def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
+def mapping_text(full_pred, text_pred, print_sentences=False, percentage_thresh=0.6, scale=1.0, iou_threshold=0.5):
+    """
+    Map the extracted texts to the predicted bounding boxes.

+    Parameters:
+    - full_pred (dict): Full prediction dictionary containing boxes, labels, BPMN IDs, and pool dictionary.
+    - text_pred (list): List containing text predictions and their bounding boxes.
+    - print_sentences (bool): Whether to print the sentences and their bounding boxes.
+    - percentage_thresh (float): Percentage threshold for determining closeness.
+    - scale (float): Scale factor for rescaling bounding boxes.
+    - iou_threshold (float): IoU threshold for determining if text is inside a bounding box.
+
+    Returns:
+    - dict: Text mapping for BPMN elements.
+    """
     boxes = rescale(scale, full_pred['boxes'])

     min_dist = 200
     labels = full_pred['labels']
     avoid = [list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation')]
     for i in range(len(boxes)):
-
-
+        box1 = boxes[i]
+        if labels[i] in avoid:
+            continue
+        for j in range(i + 1, len(boxes)):
+            box2 = boxes[j]
+            if labels[j] in avoid:
                 continue
-
-
-        if labels[j] in avoid:
-            continue
-        dist = min_distance_between_boxes(box1, box2)
-        min_dist = min(min_dist, dist)
+            dist = min_distance_between_boxes(box1, box2)
+            min_dist = min(min_dist, dist)

-    #
+    # Print the minimum distance between boxes
+    # print("Minimum distance between boxes:", min_dist)

     text_pred[0] = rescale(scale, text_pred[0])
     task_boxes = [box for i, box in enumerate(boxes) if full_pred['labels'][i] == list(class_dict.values()).index('task')]
     grouped_sentences, sentence_bounding_boxes, info_texts, info_boxes = group_texts(task_boxes, text_pred[0], text_pred[1], min_dist=min_dist)
     BPMN_id = set(full_pred['BPMN_id']) # This ensures uniqueness of task names
     text_mapping = {id: '' for id in BPMN_id}
-

     if print_sentences:
         for sentence, box in zip(grouped_sentences, sentence_bounding_boxes):
@@ -363,8 +466,8 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
     # Map the grouped sentences to the corresponding task
     for i in range(len(sentence_bounding_boxes)):
         for j in range(len(boxes)):
-            if proportion_inside(sentence_bounding_boxes[i], boxes[j])>iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
-                text_mapping[full_pred['BPMN_id'][j]]=grouped_sentences[i]
+            if proportion_inside(sentence_bounding_boxes[i], boxes[j]) > iou_threshold and full_pred['labels'][j] == list(class_dict.values()).index('task'):
+                text_mapping[full_pred['BPMN_id'][j]] = grouped_sentences[i]

     # Map the grouped sentences to the corresponding pool
     for key, elements in full_pred['pool_dict'].items():
@@ -372,17 +475,16 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
             continue
         else:
             for i in range(len(info_boxes)):
-                #
+                # Find the position of the key in BPMN_id
                 position = list(full_pred['BPMN_id']).index(key)
-                if proportion_inside(info_boxes[i], boxes[position])>iou_threshold:
+                if proportion_inside(info_boxes[i], boxes[position]) > iou_threshold:
                     text_mapping[key] = info_texts[i]
                     info_texts[i] = '' # Clear the text to avoid re-use

-
     for i in range(len(info_boxes)):
         if is_vertical(info_boxes[i]):
             for j in range(len(boxes)):
-                if proportion_inside(info_boxes[i], boxes[j])>0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
+                if proportion_inside(info_boxes[i], boxes[j]) > 0 and full_pred['labels'][j] == list(class_dict.values()).index('pool'):
                     print("Text:", info_texts[i], "associate with ", full_pred['BPMN_id'][j])
                     bpmn_id = full_pred['BPMN_id'][j]
                     # Append new text or create new entry if not existing
@@ -399,10 +501,10 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
         for j in range(len(boxes)):
             if info_texts[i] == '':
                 continue # Skip if there's no text
-            if (proportion_inside(info_boxes[i], boxes[j])>0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh*min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
+            if (proportion_inside(info_boxes[i], boxes[j]) > 0 or are_close(info_boxes[i], boxes[j], threshold=percentage_thresh * min_dist)) and (full_pred['labels'][j] == list(class_dict.values()).index('event')
                     or full_pred['labels'][j] == list(class_dict.values()).index('messageEvent')
                     or full_pred['labels'][j] == list(class_dict.values()).index('timerEvent')
-                    or full_pred['labels'][j] == list(class_dict.values()).index('dataObject'))
+                    or full_pred['labels'][j] == list(class_dict.values()).index('dataObject')):
                 bpmn_id = full_pred['BPMN_id'][j]
                 # Append new text or create new entry if not existing
                 if bpmn_id in text_mapping:
@@ -416,7 +518,7 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
         if info_texts[i] == '' or is_vertical(info_boxes[i]):
             continue # Skip if there's no text
         # Find the closest box within the defined threshold
-        closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4*min_dist)
+        closest_index = find_closest_box(info_boxes[i], boxes, full_pred['labels'], threshold=4 * min_dist)
         if closest_index is not None and (full_pred['labels'][closest_index] == list(class_dict.values()).index('sequenceFlow') or full_pred['labels'][closest_index] == list(class_dict.values()).index('messageFlow')):
             bpmn_id = full_pred['BPMN_id'][closest_index]
             # Append new text or create new entry if not existing
@@ -430,4 +532,4 @@ def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0
         print("Text Mapping:", text_mapping)
         print("Information Texts left:", info_texts)

-    return text_mapping
+    return text_mapping
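The new docstrings above describe get_box_points and min_distance_between_boxes as sampling box corners and edge midpoints. A self-contained NumPy sketch of that idea, with illustrative names rather than the module's API:

import numpy as np

def box_points(box):
    # Corners and edge midpoints of an [xmin, ymin, xmax, ymax] box
    xmin, ymin, xmax, ymax = box
    return np.array([
        [xmin, ymin], [xmin, ymax], [xmax, ymin], [xmax, ymax],
        [(xmin + xmax) / 2, ymin], [(xmin + xmax) / 2, ymax],
        [xmin, (ymin + ymax) / 2], [xmax, (ymin + ymax) / 2],
    ])

def min_box_distance(box1, box2):
    # Smallest pairwise distance between the sampled points of the two boxes
    p1, p2 = box_points(box1), box_points(box2)
    return float(np.min(np.linalg.norm(p1[:, None, :] - p2[None, :, :], axis=-1)))

print(min_box_distance([0, 0, 10, 10], [15, 0, 25, 10]))  # 5.0 for two boxes 5 px apart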
modules/dataset_loader.py
CHANGED
@@ -1,7 +1,3 @@
|
|
1 |
-
from torchvision.models.detection import keypointrcnn_resnet50_fpn
|
2 |
-
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
3 |
-
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
|
4 |
-
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
|
5 |
import random
|
6 |
import torch
|
7 |
from torch.utils.data import Dataset
|
@@ -9,43 +5,60 @@ import torchvision.transforms.functional as F
|
|
9 |
import numpy as np
|
10 |
from torch.utils.data.dataloader import default_collate
|
11 |
import cv2
|
12 |
-
|
13 |
-
from torch.utils.data import DataLoader, Subset, ConcatDataset
|
14 |
-
import streamlit as st
|
15 |
from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints
|
|
|
|
|
16 |
|
17 |
class RandomCrop:
|
18 |
-
def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
self.crop_fraction = crop_fraction
|
20 |
self.min_objects = min_objects
|
21 |
self.new_size = new_size
|
22 |
|
23 |
def __call__(self, image, target):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
new_w1, new_h1 = self.new_size
|
25 |
w, h = image.size
|
26 |
new_w = int(w * self.crop_fraction)
|
27 |
-
new_h = int(new_w*new_h1/new_w1)
|
28 |
-
|
29 |
-
i=0
|
30 |
-
for i in range(4):
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
-
if new_h >= h:
|
39 |
-
|
40 |
|
41 |
boxes = target["boxes"]
|
42 |
if 'keypoints' in target:
|
43 |
keypoints = target["keypoints"]
|
44 |
else:
|
45 |
keypoints = []
|
46 |
-
for
|
47 |
-
keypoints.append(torch.zeros((2,3)))
|
48 |
-
|
49 |
|
50 |
# Attempt to find a suitable crop region
|
51 |
success = False
|
@@ -82,7 +95,7 @@ class RandomCrop:
|
|
82 |
class RandomFlip:
|
83 |
def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
|
84 |
"""
|
85 |
-
|
86 |
|
87 |
Parameters:
|
88 |
- h_flip_prob (float): Probability of applying a horizontal flip to the image.
|
@@ -93,7 +106,7 @@ class RandomFlip:
|
|
93 |
|
94 |
def __call__(self, image, target):
|
95 |
"""
|
96 |
-
|
97 |
|
98 |
Parameters:
|
99 |
- image (PIL Image): The image to be flipped.
|
@@ -143,12 +156,12 @@ class RandomFlip:
|
|
143 |
target['keypoints'] = torch.stack(new_keypoints)
|
144 |
|
145 |
return image, target
|
146 |
-
|
147 |
|
148 |
class RandomRotate:
|
149 |
def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
|
150 |
"""
|
151 |
-
|
152 |
|
153 |
Parameters:
|
154 |
- max_rotate_deg (int): Maximum degree to rotate the image.
|
@@ -159,7 +172,7 @@ class RandomRotate:
|
|
159 |
|
160 |
def __call__(self, image, target):
|
161 |
"""
|
162 |
-
Randomly
|
163 |
|
164 |
Parameters:
|
165 |
- image (PIL Image): The image to be rotated.
|
@@ -170,7 +183,7 @@ class RandomRotate:
|
|
170 |
"""
|
171 |
if random.random() < self.rotate_proba:
|
172 |
angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
|
173 |
-
image = F.rotate(image, angle, expand=False, fill=
|
174 |
|
175 |
# Rotate bounding boxes
|
176 |
w, h = image.size
|
@@ -194,7 +207,16 @@ class RandomRotate:
|
|
194 |
|
195 |
def rotate_box(self, box, angle, cx, cy):
|
196 |
"""
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
"""
|
199 |
x1, y1, x2, y2 = box
|
200 |
corners = torch.tensor([
|
@@ -214,7 +236,16 @@ class RandomRotate:
|
|
214 |
|
215 |
def rotate_keypoints(self, keypoints, angle, cx, cy):
|
216 |
"""
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
"""
|
219 |
new_keypoints = []
|
220 |
for kp in keypoints:
|
@@ -226,50 +257,89 @@ class RandomRotate:
|
|
226 |
return torch.stack(new_keypoints)
|
227 |
|
228 |
def rotate_90_box(box, angle, w, h):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
x1, y1, x2, y2 = box
|
230 |
if angle == 90:
|
231 |
-
return torch.tensor([y1,h-x2,y2,h-x1])
|
232 |
elif angle == 270 or angle == -90:
|
233 |
-
return torch.tensor([w-y2,x1,w-y1,x2])
|
234 |
else:
|
235 |
print("angle not supported")
|
236 |
|
237 |
def rotate_90_keypoints(kp, angle, w, h):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
# Extract coordinates and visibility from each keypoint tensor
|
239 |
x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
|
240 |
x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
|
241 |
# Swap x and y coordinates for each keypoint
|
242 |
if angle == 90:
|
243 |
-
new = [[y1, h-x1, v1], [y2, h-x2, v2]]
|
244 |
elif angle == 270 or angle == -90:
|
245 |
-
new = [[w-y1, x1, v1], [w-y2, x2, v2]]
|
246 |
|
247 |
return torch.tensor(new, dtype=torch.float32)
|
248 |
-
|
249 |
|
250 |
def rotate_vertical(image, target):
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
new_boxes = []
|
253 |
-
angle = random.choice([-90,90])
|
254 |
image = F.rotate(image, angle, expand=True, fill=200)
|
255 |
for box in target["boxes"]:
|
256 |
new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
|
257 |
new_boxes.append(new_box)
|
258 |
target["boxes"] = torch.stack(new_boxes)
|
259 |
-
|
260 |
if 'keypoints' in target:
|
261 |
-
new_kp = []
|
262 |
-
for kp in target['keypoints']:
|
263 |
new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
|
264 |
new_kp.append(new_key)
|
265 |
target['keypoints'] = torch.stack(new_kp)
|
266 |
return image, target
|
267 |
|
|
|
|
|
|
|
268 |
|
269 |
-
|
270 |
-
|
|
|
|
|
271 |
|
272 |
-
|
|
|
|
|
273 |
original_size = image.size
|
274 |
# Calculate scale to fit the new size while maintaining aspect ratio
|
275 |
scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
|
@@ -302,8 +372,24 @@ def resize_and_pad(image, target, new_size=(1333, 800)):
|
|
302 |
return image, target
|
303 |
|
304 |
class BPMN_Dataset(Dataset):
|
305 |
-
def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
|
306 |
-
flip_transform=None, rotate_transform=None, new_size=(1333,1333), keep_ratio=0.1, resize=True, model_type='object'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
self.annotations = annotations
|
308 |
print(f"Loaded {len(self.annotations)} annotations.")
|
309 |
self.transform = transform
|
@@ -322,15 +408,30 @@ class BPMN_Dataset(Dataset):
|
|
322 |
self.rotate_90_proba = rotate_90_proba
|
323 |
|
324 |
def __len__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
return len(self.annotations)
|
326 |
|
327 |
def __getitem__(self, idx):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
annotation = self.annotations[idx]
|
329 |
image = annotation.img.convert("RGB")
|
330 |
boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
|
331 |
labels_names = [ann for ann in annotation.categories]
|
332 |
|
333 |
-
# Only keep the labels, boxes and keypoints that are in the class_dict
|
334 |
kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
|
335 |
boxes = boxes[kept_indices]
|
336 |
labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
|
@@ -351,7 +452,7 @@ class BPMN_Dataset(Dataset):
|
|
351 |
if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
|
352 |
# Fill the keypoints tensor for this annotation, mark as visible (1)
|
353 |
kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
|
354 |
-
kp = kp[
|
355 |
visible = np.ones((kp.shape[0], 1), dtype=np.float32)
|
356 |
kp = np.hstack([kp, visible])
|
357 |
keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
|
@@ -359,17 +460,17 @@ class BPMN_Dataset(Dataset):
|
|
359 |
|
360 |
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
361 |
|
362 |
-
if self.model_type == 'object':
|
363 |
target = {
|
364 |
"boxes": boxes,
|
365 |
"labels": labels_id,
|
366 |
-
#"area": area,
|
367 |
}
|
368 |
elif self.model_type == 'arrow':
|
369 |
target = {
|
370 |
"boxes": boxes,
|
371 |
"labels": labels_id,
|
372 |
-
#"area": area,
|
373 |
"keypoints": keypoints,
|
374 |
}
|
375 |
|
@@ -384,7 +485,7 @@ class BPMN_Dataset(Dataset):
|
|
384 |
# Randomly apply the custom cropping transform
|
385 |
if self.crop_transform and random.random() < self.crop_prob:
|
386 |
image, target = self.crop_transform(image, target)
|
387 |
-
|
388 |
# Rotate vertical image
|
389 |
if random.random() < self.rotate_90_proba:
|
390 |
image, target = rotate_vertical(image, target)
|
@@ -394,12 +495,12 @@ class BPMN_Dataset(Dataset):
|
|
394 |
# Center and pad the image while keeping the aspect ratio
|
395 |
image, target = resize_and_pad(image, target, self.new_size)
|
396 |
else:
|
397 |
-
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
|
398 |
if 'area' in target:
|
399 |
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
400 |
if 'keypoints' in target:
|
401 |
for i in range(len(target['keypoints'])):
|
402 |
-
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
|
403 |
image = F.resize(image, (self.new_size[1], self.new_size[0]))
|
404 |
|
405 |
return self.transform(image), target
|
@@ -429,15 +530,15 @@ def collate_fn(batch):
|
|
429 |
return images, targets
|
430 |
|
431 |
|
432 |
-
|
433 |
-
def create_loader(new_size,transformation, annotations1, annotations2=None,
|
434 |
batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
|
435 |
h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
|
436 |
-
seed=42, resize=True, keep_ratio=0.1, model_type
|
437 |
"""
|
438 |
-
|
439 |
|
440 |
Parameters:
|
|
|
441 |
- transformation (callable): Transformation function to apply to each image (e.g., normalization).
|
442 |
- annotations1 (list): Primary list of annotations.
|
443 |
- annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
|
@@ -447,15 +548,20 @@ def create_loader(new_size,transformation, annotations1, annotations2=None,
|
|
447 |
- min_objects (int): Minimum number of objects required to be within the crop.
|
448 |
- h_flip_prob (float): Probability of applying horizontal flip.
|
449 |
- v_flip_prob (float): Probability of applying vertical flip.
|
|
|
|
|
|
|
450 |
- seed (int): Seed for random number generators for reproducibility.
|
451 |
- resize (bool): Flag indicating whether to resize images after transformations.
|
|
|
|
|
452 |
|
453 |
Returns:
|
454 |
- DataLoader: Configured data loader for the dataset.
|
455 |
"""
|
456 |
|
457 |
# Initialize custom transformations for cropping and flipping
|
458 |
-
custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
|
459 |
custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
|
460 |
custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
|
461 |
|
@@ -497,4 +603,4 @@ def create_loader(new_size,transformation, annotations1, annotations2=None,
|
|
497 |
# Create the DataLoader with the dataset
|
498 |
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
|
499 |
|
500 |
-
return data_loader
|
|
|
|
|
|
|
|
|
|
|
1 |
import random
|
2 |
import torch
|
3 |
from torch.utils.data import Dataset
|
|
|
5 |
import numpy as np
|
6 |
from torch.utils.data.dataloader import default_collate
|
7 |
import cv2
|
8 |
+
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
|
|
|
|
|
9 |
from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
import torch
|
12 |
|
13 |
class RandomCrop:
|
14 |
+
def __init__(self, new_size=(1333, 800), crop_fraction=0.5, min_objects=4):
|
15 |
+
"""
|
16 |
+
Initialize the RandomCrop transformation.
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
- new_size (tuple): The target size for the image after cropping.
|
20 |
+
- crop_fraction (float): The fraction of the original width to use when cropping.
|
21 |
+
- min_objects (int): Minimum number of objects required to be within the crop.
|
22 |
+
"""
|
23 |
self.crop_fraction = crop_fraction
|
24 |
self.min_objects = min_objects
|
25 |
self.new_size = new_size
|
26 |
|
27 |
def __call__(self, image, target):
|
28 |
+
"""
|
29 |
+
Apply the RandomCrop transformation to the image and its target.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
- image (PIL Image): The image to be cropped.
|
33 |
+
- target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
- PIL Image, dict: The cropped image and its updated target dictionary.
|
37 |
+
"""
|
38 |
new_w1, new_h1 = self.new_size
|
39 |
w, h = image.size
|
40 |
new_w = int(w * self.crop_fraction)
|
41 |
+
new_h = int(new_w * new_h1 / new_w1)
|
42 |
+
|
43 |
+
i = 0
|
44 |
+
for i in range(4): # Try 4 times to adjust new_w and new_h if new_h >= h
|
45 |
+
if new_h >= h:
|
46 |
+
i += 0.05
|
47 |
+
new_w = int(w * (self.crop_fraction - i))
|
48 |
+
new_h = int(new_w * new_h1 / new_w1)
|
49 |
+
if new_h < h:
|
50 |
+
continue
|
51 |
|
52 |
+
if new_h >= h: # If still not valid, return original image and target
|
53 |
+
return image, target
|
54 |
|
55 |
boxes = target["boxes"]
|
56 |
if 'keypoints' in target:
|
57 |
keypoints = target["keypoints"]
|
58 |
else:
|
59 |
keypoints = []
|
60 |
+
for _ in range(len(boxes)):
|
61 |
+
keypoints.append(torch.zeros((2, 3)))
|
|
|
62 |
|
63 |
# Attempt to find a suitable crop region
|
64 |
success = False
|
|
|
95 |
class RandomFlip:
|
96 |
def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
|
97 |
"""
|
98 |
+
Initialize the RandomFlip transformation with probabilities for flipping.
|
99 |
|
100 |
Parameters:
|
101 |
- h_flip_prob (float): Probability of applying a horizontal flip to the image.
|
|
|
106 |
|
107 |
def __call__(self, image, target):
|
108 |
"""
|
109 |
+
Apply random horizontal and/or vertical flip to the image and updates target data accordingly.
|
110 |
|
111 |
Parameters:
|
112 |
- image (PIL Image): The image to be flipped.
|
|
|
156 |
target['keypoints'] = torch.stack(new_keypoints)
|
157 |
|
158 |
return image, target
|
159 |
+
|
160 |
|
161 |
class RandomRotate:
|
162 |
def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
|
163 |
"""
|
164 |
+
Initialize the RandomRotate transformation with a maximum rotation angle and probability of rotating.
|
165 |
|
166 |
Parameters:
|
167 |
- max_rotate_deg (int): Maximum degree to rotate the image.
|
|
|
172 |
|
173 |
def __call__(self, image, target):
|
174 |
"""
|
175 |
+
Randomly rotate the image and updates the target data accordingly.
|
176 |
|
177 |
Parameters:
|
178 |
- image (PIL Image): The image to be rotated.
|
|
|
183 |
"""
|
184 |
if random.random() < self.rotate_proba:
|
185 |
angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
|
186 |
+
image = F.rotate(image, angle, expand=False, fill=255)
|
187 |
|
188 |
# Rotate bounding boxes
|
189 |
w, h = image.size
|
|
|
207 |
|
208 |
def rotate_box(self, box, angle, cx, cy):
|
209 |
"""
|
210 |
+
Rotate a bounding box by a given angle around the center of the image.
|
211 |
+
|
212 |
+
Parameters:
|
213 |
+
- box (tensor): The bounding box to be rotated.
|
214 |
+
- angle (float): The angle to rotate the box.
|
215 |
+
- cx (float): The x-coordinate of the image center.
|
216 |
+
- cy (float): The y-coordinate of the image center.
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
- tensor: The rotated bounding box.
|
220 |
"""
|
221 |
x1, y1, x2, y2 = box
|
222 |
corners = torch.tensor([
|
|
|
236 |
|
237 |
def rotate_keypoints(self, keypoints, angle, cx, cy):
|
238 |
"""
|
239 |
+
Rotate keypoints by a given angle around the center of the image.
|
240 |
+
|
241 |
+
Parameters:
|
242 |
+
- keypoints (tensor): The keypoints to be rotated.
|
243 |
+
- angle (float): The angle to rotate the keypoints.
|
244 |
+
- cx (float): The x-coordinate of the image center.
|
245 |
+
- cy (float): The y-coordinate of the image center.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
- tensor: The rotated keypoints.
|
249 |
"""
|
250 |
new_keypoints = []
|
251 |
for kp in keypoints:
|
|
|
257 |
return torch.stack(new_keypoints)
|
258 |
|
259 |
def rotate_90_box(box, angle, w, h):
|
260 |
+
"""
|
261 |
+
Rotate a bounding box by 90 degrees.
|
262 |
+
|
263 |
+
Parameters:
|
264 |
+
- box (tensor): The bounding box to be rotated.
|
265 |
+
- angle (int): The angle to rotate the box (90 or -90 degrees).
|
266 |
+
- w (int): The width of the image.
|
267 |
+
- h (int): The height of the image.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
- tensor: The rotated bounding box.
|
271 |
+
"""
|
272 |
x1, y1, x2, y2 = box
|
273 |
if angle == 90:
|
274 |
+
return torch.tensor([y1, h - x2, y2, h - x1])
|
275 |
elif angle == 270 or angle == -90:
|
276 |
+
return torch.tensor([w - y2, x1, w - y1, x2])
|
277 |
else:
|
278 |
print("angle not supported")
|
279 |
|
280 |
def rotate_90_keypoints(kp, angle, w, h):
|
281 |
+
"""
|
282 |
+
Rotate keypoints by 90 degrees.
|
283 |
+
|
284 |
+
Parameters:
|
285 |
+
- kp (tensor): The keypoints to be rotated.
|
286 |
+
- angle (int): The angle to rotate the keypoints (90 or -90 degrees).
|
287 |
+
- w (int): The width of the image.
|
288 |
+
- h (int): The height of the image.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
- tensor: The rotated keypoints.
|
292 |
+
"""
|
293 |
# Extract coordinates and visibility from each keypoint tensor
|
294 |
x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
|
295 |
x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
|
296 |
# Swap x and y coordinates for each keypoint
|
297 |
if angle == 90:
|
298 |
+
new = [[y1, h - x1, v1], [y2, h - x2, v2]]
|
299 |
elif angle == 270 or angle == -90:
|
300 |
+
new = [[w - y1, x1, v1], [w - y2, x2, v2]]
|
301 |
|
302 |
return torch.tensor(new, dtype=torch.float32)
|
|
|
303 |
|
304 |
def rotate_vertical(image, target):
|
305 |
+
"""
|
306 |
+
Rotate the image and target if the image is vertical.
|
307 |
+
|
308 |
+
Parameters:
|
309 |
+
- image (PIL Image): The image to be rotated.
|
310 |
+
- target (dict): The target dictionary containing 'boxes' and 'keypoints'.
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
- PIL Image, dict: The rotated image and its updated target dictionary.
|
314 |
+
"""
|
315 |
new_boxes = []
|
316 |
+
angle = random.choice([-90, 90])
|
317 |
image = F.rotate(image, angle, expand=True, fill=200)
|
318 |
for box in target["boxes"]:
|
319 |
new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
|
320 |
new_boxes.append(new_box)
|
321 |
target["boxes"] = torch.stack(new_boxes)
|
322 |
+
|
323 |
if 'keypoints' in target:
|
324 |
+
new_kp = []
|
325 |
+
for kp in target['keypoints']:
|
326 |
new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
|
327 |
new_kp.append(new_key)
|
328 |
target['keypoints'] = torch.stack(new_kp)
|
329 |
return image, target
|
330 |
|
331 |
+
def resize_and_pad(image, target, new_size=(1333, 800)):
|
332 |
+
"""
|
333 |
+
Resize and pad the image and target to the specified new size while maintaining the aspect ratio.
|
334 |
|
335 |
+
Parameters:
|
336 |
+
- image (PIL Image): The image to be resized and padded.
|
337 |
+
- target (dict): The target dictionary containing 'boxes' and optional 'keypoints'.
|
338 |
+
- new_size (tuple): The target size for the image after resizing and padding.
|
339 |
|
340 |
+
Returns:
|
341 |
+
- PIL Image, dict: The resized and padded image and its updated target dictionary.
|
342 |
+
"""
|
343 |
original_size = image.size
|
344 |
# Calculate scale to fit the new size while maintaining aspect ratio
|
345 |
scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
|
|
|
372 |
return image, target
|
373 |
|
374 |
class BPMN_Dataset(Dataset):
|
375 |
+
def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
|
376 |
+
flip_transform=None, rotate_transform=None, new_size=(1333, 1333), keep_ratio=0.1, resize=True, model_type='object'):
|
377 |
+
"""
|
378 |
+
Initialize the BPMN_Dataset with annotations and optional transformations.
|
379 |
+
|
380 |
+
Parameters:
|
381 |
+
- annotations (list): List of annotations for the dataset.
|
382 |
+
- transform (callable, optional): Transformation function to apply to each image.
|
383 |
+
- crop_transform (callable, optional): Custom cropping transformation.
|
384 |
+
- crop_prob (float): Probability of applying the crop transformation.
|
385 |
+
- rotate_90_proba (float): Probability of rotating the image by 90 degrees.
|
386 |
+
- flip_transform (callable, optional): Custom flipping transformation.
|
387 |
+
- rotate_transform (callable, optional): Custom rotation transformation.
|
388 |
+
- new_size (tuple): Target size for the images.
|
389 |
+
- keep_ratio (float): Probability of keeping the aspect ratio during resizing.
|
390 |
+
- resize (bool): Flag indicating whether to resize images after transformations.
|
391 |
+
- model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.
|
392 |
+
"""
|
393 |
self.annotations = annotations
|
394 |
print(f"Loaded {len(self.annotations)} annotations.")
|
395 |
self.transform = transform
|
|
|
408 |
self.rotate_90_proba = rotate_90_proba
|
409 |
|
410 |
def __len__(self):
|
411 |
+
"""
|
412 |
+
Return the number of annotations in the dataset.
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
- int: The number of annotations.
|
416 |
+
"""
|
417 |
return len(self.annotations)
|
418 |
|
419 |
def __getitem__(self, idx):
|
420 |
+
"""
|
421 |
+
Get an item (image and target) from the dataset at the specified index.
|
422 |
+
|
423 |
+
Parameters:
|
424 |
+
- idx (int): The index of the item to retrieve.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
- PIL Image, dict: The transformed image and its updated target dictionary.
|
428 |
+
"""
|
429 |
annotation = self.annotations[idx]
|
430 |
image = annotation.img.convert("RGB")
|
431 |
boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
|
432 |
labels_names = [ann for ann in annotation.categories]
|
433 |
|
434 |
+
# Only keep the labels, boxes, and keypoints that are in the class_dict
|
435 |
kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
|
436 |
boxes = boxes[kept_indices]
|
437 |
labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
|
|
|
452 |
if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
|
453 |
# Fill the keypoints tensor for this annotation, mark as visible (1)
|
454 |
kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
|
455 |
+
kp = kp[:, :2]
|
456 |
visible = np.ones((kp.shape[0], 1), dtype=np.float32)
|
457 |
kp = np.hstack([kp, visible])
|
458 |
keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
|
|
|
460 |
|
461 |
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
462 |
|
463 |
+
if self.model_type == 'object':
|
464 |
target = {
|
465 |
"boxes": boxes,
|
466 |
"labels": labels_id,
|
467 |
+
# "area": area,
|
468 |
}
|
469 |
elif self.model_type == 'arrow':
|
470 |
target = {
|
471 |
"boxes": boxes,
|
472 |
"labels": labels_id,
|
473 |
+
# "area": area,
|
474 |
"keypoints": keypoints,
|
475 |
}
|
476 |
|
|
|
485 |
# Randomly apply the custom cropping transform
|
486 |
if self.crop_transform and random.random() < self.crop_prob:
|
487 |
image, target = self.crop_transform(image, target)
|
488 |
+
|
489 |
# Rotate vertical image
|
490 |
if random.random() < self.rotate_90_proba:
|
491 |
image, target = rotate_vertical(image, target)
|
|
|
495 |
# Center and pad the image while keeping the aspect ratio
|
496 |
image, target = resize_and_pad(image, target, self.new_size)
|
497 |
else:
|
498 |
+
target['boxes'] = resize_boxes(target['boxes'], (image.size[0], image.size[1]), self.new_size)
|
499 |
if 'area' in target:
|
500 |
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
501 |
if 'keypoints' in target:
|
502 |
for i in range(len(target['keypoints'])):
|
503 |
+
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0], image.size[1]), self.new_size)
|
504 |
image = F.resize(image, (self.new_size[1], self.new_size[0]))
|
505 |
|
506 |
return self.transform(image), target
|
|
|
530 |
return images, targets
|
531 |
|
532 |
|
533 |
+
def create_loader(new_size, transformation, annotations1, annotations2=None,
|
|
|
534 |
batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
|
535 |
h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
|
536 |
+
seed=42, resize=True, keep_ratio=0.1, model_type='object'):
|
537 |
"""
|
538 |
+
Create a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
|
539 |
|
540 |
Parameters:
|
541 |
+
- new_size (tuple): The target size for the images.
|
542 |
- transformation (callable): Transformation function to apply to each image (e.g., normalization).
|
543 |
- annotations1 (list): Primary list of annotations.
|
544 |
- annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
|
|
|
548 |
- min_objects (int): Minimum number of objects required to be within the crop.
|
549 |
- h_flip_prob (float): Probability of applying horizontal flip.
|
550 |
- v_flip_prob (float): Probability of applying vertical flip.
|
551 |
+
- max_rotate_deg (int): Maximum degree to rotate the image.
|
552 |
+
- rotate_90_proba (float): Probability of rotating the image by 90 degrees.
|
553 |
+
- rotate_proba (float): Probability of applying rotation to the image.
|
554 |
- seed (int): Seed for random number generators for reproducibility.
|
555 |
- resize (bool): Flag indicating whether to resize images after transformations.
|
556 |
+
- keep_ratio (float): Probability of keeping the aspect ratio during resizing.
|
557 |
+
- model_type (str): Type of model ('object' or 'arrow') to determine the target dictionary.
|
558 |
|
559 |
Returns:
|
560 |
- DataLoader: Configured data loader for the dataset.
|
561 |
"""
|
562 |
|
563 |
# Initialize custom transformations for cropping and flipping
|
564 |
+
custom_crop_transform = RandomCrop(new_size, crop_fraction, min_objects)
|
565 |
custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
|
566 |
custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
|
567 |
|
|
|
603 |
# Create the DataLoader with the dataset
|
604 |
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
|
605 |
|
606 |
+
return data_loader
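A minimal usage sketch of the loader factory above; the annotation list and the normalization transform are placeholders, not values taken from this repository:

import torchvision.transforms as T

train_annotations = []  # replace with the annotation list the dataset expects (format assumed)
transform = T.Compose([T.ToTensor()])

loader = create_loader(
    new_size=(1333, 1333),
    transformation=transform,
    annotations1=train_annotations,
    batch_size=4,
    crop_prob=0.2,
    rotate_90_proba=0.2,
    model_type='arrow',
)
for images, targets in loader:
    break  # each batch is a tuple of image tensors and target dicts assembled by collate_fn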
|
modules/eval.py
CHANGED
@@ -9,6 +9,18 @@ from builtins import dict
|
|
9 |
|
10 |
|
11 |
def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
exception = ['pool', 'lane']
|
13 |
|
14 |
idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order
|
@@ -40,6 +52,19 @@ def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
|
|
40 |
|
41 |
|
42 |
def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
for idx, (key1, key2) in enumerate(keypoints):
|
44 |
if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
|
45 |
list(model_dict.values()).index('messageFlow'),
|
@@ -49,14 +74,26 @@ def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distanc
|
|
49 |
distance = np.linalg.norm(key1[:2] - key2[:2])
|
50 |
if distance < distance_treshold:
|
51 |
print('Key modified for index:', idx)
|
52 |
-
x_new,y_new, x,y = find_other_keypoint(idx, keypoints, boxes)
|
53 |
-
keypoints[idx][0][:2] = [x_new,y_new]
|
54 |
-
keypoints[idx][1][:2] = [x,y]
|
55 |
|
56 |
return keypoints
|
57 |
|
58 |
|
59 |
def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
model.eval()
|
61 |
with torch.no_grad():
|
62 |
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
@@ -73,7 +110,7 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
|
73 |
|
74 |
selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
|
75 |
|
76 |
-
#
|
77 |
vertical = 0
|
78 |
for i in range(len(labels)):
|
79 |
if labels[i] != list(object_dict.values()).index('task'):
|
@@ -87,12 +124,12 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
|
87 |
|
88 |
if vertical < horizontal:
|
89 |
if is_vertical(boxes[i]):
|
90 |
-
#
|
91 |
if i in selected_boxes:
|
92 |
selected_boxes.remove(i)
|
93 |
elif vertical > horizontal:
|
94 |
if is_vertical(boxes[i]) == False:
|
95 |
-
#
|
96 |
if i in selected_boxes:
|
97 |
selected_boxes.remove(i)
|
98 |
else:
|
@@ -102,23 +139,21 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
|
102 |
scores = scores[selected_boxes]
|
103 |
labels = labels[selected_boxes]
|
104 |
|
105 |
-
#
|
106 |
-
obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower")
|
107 |
-
obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref=['task'], mode="upper")
|
108 |
|
109 |
selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
|
110 |
|
111 |
-
#selected_object = obj_not_too_small
|
112 |
-
|
113 |
boxes = boxes[selected_object]
|
114 |
scores = scores[selected_object]
|
115 |
labels = labels[selected_object]
|
116 |
|
117 |
-
#
|
118 |
for i in range(len(labels)):
|
119 |
if labels[i] == list(object_dict.values()).index('subProcess'):
|
120 |
labels[i] = list(object_dict.values()).index('task')
|
121 |
-
#
|
122 |
lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
|
123 |
boxes = np.delete(boxes, lane_index, axis=0)
|
124 |
labels = np.delete(labels, lane_index)
|
@@ -137,6 +172,19 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
|
137 |
|
138 |
|
139 |
def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
model.eval()
|
141 |
with torch.no_grad():
|
142 |
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
@@ -173,7 +221,18 @@ def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, dista
|
|
173 |
|
174 |
return image, prediction
|
175 |
|
|
|
176 |
def mix_predictions(objects_pred, arrow_pred):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
# Initialize the list of lists for keypoints
|
178 |
object_keypoints = []
|
179 |
|
@@ -186,7 +245,7 @@ def mix_predictions(objects_pred, arrow_pred):
|
|
186 |
keypoints = [[0, 0, 0], [0, 0, 0]]
|
187 |
object_keypoints.append(keypoints)
|
188 |
|
189 |
-
#
|
190 |
if len(arrow_pred['boxes']) == 0:
|
191 |
return objects_pred['boxes'], objects_pred['labels'], objects_pred['scores'], object_keypoints
|
192 |
|
@@ -199,6 +258,21 @@ def mix_predictions(objects_pred, arrow_pred):
|
|
199 |
|
200 |
|
201 |
def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
pool_dict = {}
|
203 |
|
204 |
# Filter out pools with IoU greater than the threshold
|
@@ -265,12 +339,24 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
|
|
265 |
return pool_dict, boxes, labels, scores, keypoints
|
266 |
|
267 |
|
268 |
-
|
269 |
def create_links(keypoints, boxes, labels, class_dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
best_points = []
|
271 |
links = []
|
272 |
for i in range(len(labels)):
|
273 |
-
if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
|
274 |
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
275 |
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
276 |
|
@@ -278,11 +364,11 @@ def create_links(keypoints, boxes, labels, class_dict):
|
|
278 |
best_points.append([point_start, point_end])
|
279 |
links.append([closest1, closest2])
|
280 |
else:
|
281 |
-
best_points.append([None,None])
|
282 |
-
links.append([None,None])
|
283 |
|
284 |
for i in range(len(labels)):
|
285 |
-
if labels[i]==list(class_dict.values()).index('dataAssociation'):
|
286 |
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
287 |
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
288 |
if closest1 is not None and closest2 is not None:
|
@@ -291,7 +377,22 @@ def create_links(keypoints, boxes, labels, class_dict):
|
|
291 |
|
292 |
return links, best_points
|
293 |
|
|
|
294 |
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
sequence_flow_index = list(class_dict.values()).index('sequenceFlow')
|
296 |
message_flow_index = list(class_dict.values()).index('messageFlow')
|
297 |
data_association_index = list(class_dict.values()).index('dataAssociation')
|
@@ -339,7 +440,21 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
|
339 |
return labels, flow_links
|
340 |
|
341 |
|
342 |
-
def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
# Filter out the sizes of events, data objects, and message events
|
344 |
event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
|
345 |
event_boxes = [boxes[i] for i in event_indices]
|
@@ -360,7 +475,7 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
|
|
360 |
kept_indices = []
|
361 |
|
362 |
if mode == "lower" or mode == 'both':
|
363 |
-
#
|
364 |
for idx, (box, label) in enumerate(zip(boxes, labels)):
|
365 |
area = (box[2] - box[0]) * (box[3] - box[1])
|
366 |
if not (area_lower_threshold <= area):
|
@@ -370,7 +485,7 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
|
|
370 |
kept_indices.append(idx)
|
371 |
|
372 |
if mode == "upper" or mode == 'both':
|
373 |
-
#
|
374 |
for idx, (box, label) in enumerate(zip(boxes, labels)):
|
375 |
if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
|
376 |
kept_indices.append(idx)
|
@@ -382,17 +497,31 @@ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, elem
|
|
382 |
else:
|
383 |
kept_indices.append(idx)
|
384 |
|
385 |
-
|
386 |
return kept_indices
|
387 |
|
388 |
|
389 |
-
|
390 |
def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
|
391 |
-
|
392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
delete_pool = []
|
394 |
for pool_index, elements in pool_dict.items():
|
395 |
-
#
|
396 |
if pool_index in bpmn_id:
|
397 |
position = bpmn_id.index(pool_index)
|
398 |
else:
|
@@ -405,11 +534,11 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
|
|
405 |
delete_pool.append(position)
|
406 |
print(f"Pool {pool_index} contains only arrow elements, deleting it")
|
407 |
|
408 |
-
#
|
409 |
if position < len(boxes):
|
410 |
pool = boxes[position]
|
411 |
area = (pool[2] - pool[0]) * (pool[3] - pool[1])
|
412 |
-
if len(pool_dict)>1 and area < limit_area:
|
413 |
delete_pool.append(position)
|
414 |
print(f"Pool {pool_index} is too small, deleting it")
|
415 |
|
@@ -417,34 +546,23 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
|
|
417 |
delete_pool.append(position)
|
418 |
print(f"Pool {position} is vertical, deleting it")
|
419 |
|
420 |
-
|
421 |
delete_elements = []
|
422 |
# Check if there is an arrow that has the same links
|
423 |
for i in range(len(labels)):
|
424 |
-
for j in range(i+1, len(labels)):
|
425 |
if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
|
426 |
if links[i] == links[j]:
|
427 |
-
print(f'element {i} and {j} have the same links')
|
428 |
if scores[i] > scores[j]:
|
429 |
-
print('delete element', j)
|
430 |
delete_elements.append(j)
|
431 |
else:
|
432 |
-
print('delete element', i)
|
433 |
delete_elements.append(i)
|
434 |
|
435 |
-
#
|
436 |
-
"""tex_pred = st.session_state.text_pred
|
437 |
-
for i in range(len(boxes)):
|
438 |
-
for j in range(len(tex_pred[0])):
|
439 |
-
#check if the box is inside the text box but if the text box is inside the box then it is not a problem
|
440 |
-
if proportion_inside(boxes[i], tex_pred[0][j]) > 0.1:
|
441 |
-
#delete_elements.append(i)
|
442 |
-
print('delete element', i)"""
|
443 |
-
|
444 |
-
|
445 |
-
#concatenate the delete_elements and the delete_pool
|
446 |
delete_elements = delete_elements + delete_pool
|
447 |
-
#
|
448 |
delete_elements = list(set(delete_elements))
|
449 |
|
450 |
boxes = np.delete(boxes, delete_elements, axis=0)
|
@@ -456,74 +574,129 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
|
|
456 |
best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
|
457 |
|
458 |
for i in range(len(delete_pool)):
|
459 |
-
#
|
460 |
pool_index = bpmn_id[delete_pool[i]]
|
461 |
-
#
|
462 |
del pool_dict[pool_index]
|
463 |
|
464 |
bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
|
465 |
|
466 |
-
#
|
467 |
for pool_index, elements in pool_dict.items():
|
468 |
pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
|
469 |
|
470 |
return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
|
471 |
|
|
|
472 |
def give_link_to_element(links, labels):
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
482 |
|
483 |
def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict):
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
|
505 |
-
def develop_prediction(boxes, labels, scores, keypoints, class_dict):
|
506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
|
508 |
|
509 |
-
bpmn_id, pool_dict = create_BPMN_id(labels,pool_dict)
|
510 |
|
511 |
# Create links between elements
|
512 |
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
513 |
|
514 |
-
#Correct the labels of some
|
515 |
labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
|
516 |
|
517 |
-
#
|
518 |
flow_links = give_link_to_element(flow_links, labels)
|
519 |
|
520 |
-
boxes,labels,scores,keypoints,bpmn_id, flow_links,best_points,pool_dict = last_correction(boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
|
|
|
|
|
521 |
|
522 |
-
return boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
|
523 |
|
524 |
-
|
525 |
|
526 |
def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
model_object.eval() # Set the model to evaluation mode
|
528 |
model_arrow.eval() # Set the model to evaluation mode
|
529 |
|
@@ -536,7 +709,9 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
|
|
536 |
|
537 |
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
538 |
|
539 |
-
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(boxes, labels, scores, keypoints, class_dict)
|
|
|
|
|
540 |
|
541 |
image = image.permute(1, 2, 0).cpu().numpy()
|
542 |
image = (image * 255).astype(np.uint8)
|
@@ -545,7 +720,22 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
|
|
545 |
|
546 |
return image, data
|
547 |
|
|
|
548 |
def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
549 |
# Initialize dictionaries to hold per-class counts
|
550 |
class_tp = {cls: 0 for cls in model_dict.values()}
|
551 |
class_fp = {cls: 0 for cls in model_dict.values()}
|
@@ -589,10 +779,25 @@ def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, mo
|
|
589 |
return class_precision, class_recall, class_f1_score
|
590 |
|
591 |
|
592 |
-
def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
result = 0
|
594 |
reverted = False
|
595 |
-
#
|
596 |
idx = np.where(pred_boxes == pred_box)[0][0]
|
597 |
idx2 = np.where(true_boxes == true_box)[0][0]
|
598 |
|
@@ -615,7 +820,24 @@ def keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints,
|
|
615 |
|
616 |
return result, reverted
|
617 |
|
|
|
618 |
def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
tp, fp, fn = 0, 0, 0
|
620 |
key_t, key_f = 0, 0
|
621 |
labels_t, labels_f = 0, 0
|
@@ -630,7 +852,9 @@ def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred
|
|
630 |
iou_val = iou(pred_box, true_box)
|
631 |
if iou_val >= iou_threshold:
|
632 |
if true_keypoints is not None and pred_keypoints is not None:
|
633 |
-
key_result, reverted = keypoints_mesure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold)
|
|
|
|
|
634 |
key_t += key_result
|
635 |
key_f += 2 - key_result
|
636 |
if reverted:
|
@@ -653,6 +877,21 @@ def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred
|
|
653 |
|
654 |
|
655 |
def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
model.eval()
|
657 |
tp, fp, fn = 0, 0, 0
|
658 |
labels_t, labels_f = 0, 0
|
@@ -690,7 +929,7 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
|
|
690 |
filtered_labels = []
|
691 |
filtered_keypoints = []
|
692 |
if 'keypoints' not in prediction:
|
693 |
-
#
|
694 |
pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
|
695 |
|
696 |
for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
|
@@ -707,7 +946,8 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
|
|
707 |
filtered_keypoints = None
|
708 |
true_keypoints = None
|
709 |
tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
|
710 |
-
filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold)
|
|
|
711 |
|
712 |
tp += tp_img
|
713 |
fp += fp_img
|
@@ -720,9 +960,26 @@ def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, dis
|
|
720 |
|
721 |
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
|
722 |
|
723 |
-
def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type = 'object'):
|
724 |
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
|
727 |
labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
|
728 |
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
@@ -738,8 +995,21 @@ def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5,
|
|
738 |
return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
|
739 |
|
740 |
|
741 |
-
|
742 |
def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
matched_true_boxes = set()
|
744 |
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
745 |
match_found = False
|
@@ -758,7 +1028,20 @@ def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, tr
|
|
758 |
if idx not in matched_true_boxes:
|
759 |
class_fn[model_dict[true_label]] += 1
|
760 |
|
|
|
761 |
def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
762 |
model.eval()
|
763 |
with torch.no_grad():
|
764 |
for images, targets_im in tqdm(loader, desc="Testing... "):
|
@@ -788,7 +1071,21 @@ def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshol
|
|
788 |
|
789 |
yield pred_boxes, true_boxes, pred_labels, true_labels
|
790 |
|
|
|
791 |
def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
792 |
class_tp = {cls: 0 for cls in model_dict.values()}
|
793 |
class_fp = {cls: 0 for cls in model_dict.values()}
|
794 |
class_fn = {cls: 0 for cls in model_dict.values()}
|
@@ -809,4 +1106,4 @@ def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5,
|
|
809 |
class_recall[cls] = recall
|
810 |
class_f1_score[cls] = f1_score
|
811 |
|
812 |
-
return class_precision, class_recall, class_f1_score
|
|
|
9 |
|
10 |
|
11 |
def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
|
12 |
+
"""
|
13 |
+
Perform non-maximum suppression to filter out overlapping bounding boxes.
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
- boxes (array): Array of bounding boxes.
|
17 |
+
- scores (array): Array of confidence scores for each bounding box.
|
18 |
+
- labels (array, optional): Array of labels for each bounding box.
|
19 |
+
- iou_threshold (float): Intersection-over-Union threshold to use for filtering.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
- list: Indices of selected boxes after suppression.
|
23 |
+
"""
|
24 |
exception = ['pool', 'lane']
|
25 |
|
26 |
idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order
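For context, a self-contained sketch of the intersection-over-union test that suppression relies on; this is the standard formula and may differ in detail from the repository's own iou helper:

import numpy as np

def iou_example(box_a, box_b):
    # Boxes are [x1, y1, x2, y2]
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

# Two heavily overlapping boxes give an IoU above 0.5, so the lower-scoring one is dropped
print(iou_example([0, 0, 100, 100], [10, 10, 110, 110]))  # ~0.68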
|
|
|
52 |
|
53 |
|
54 |
def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15):
|
55 |
+
"""
|
56 |
+
Correct keypoints that are too close together by adjusting their positions.
|
57 |
+
|
58 |
+
Parameters:
|
59 |
+
- keypoints (array): Array of keypoints.
|
60 |
+
- boxes (array): Array of bounding boxes.
|
61 |
+
- labels (array): Array of labels for each bounding box.
|
62 |
+
- model_dict (dict): Dictionary mapping model labels to indices.
|
63 |
+
- distance_treshold (int): Distance threshold below which keypoints are considered too close.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
- array: Corrected keypoints.
|
67 |
+
"""
|
68 |
for idx, (key1, key2) in enumerate(keypoints):
|
69 |
if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'),
|
70 |
list(model_dict.values()).index('messageFlow'),
|
|
|
74 |
distance = np.linalg.norm(key1[:2] - key2[:2])
|
75 |
if distance < distance_treshold:
|
76 |
print('Key modified for index:', idx)
|
77 |
+
x_new, y_new, x, y = find_other_keypoint(idx, keypoints, boxes)
|
78 |
+
keypoints[idx][0][:2] = [x_new, y_new]
|
79 |
+
keypoints[idx][1][:2] = [x, y]
|
80 |
|
81 |
return keypoints
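A small illustration of the distance test used above, with arbitrary endpoint values:

import numpy as np

key1 = np.array([105.0, 200.0, 1.0])  # [x, y, visibility]
key2 = np.array([112.0, 206.0, 1.0])
distance = np.linalg.norm(key1[:2] - key2[:2])  # ~9.2
# With distance_treshold=15 this pair would be flagged and one endpoint re-estimated
print(distance < 15)  # True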
|
82 |
|
83 |
|
84 |
def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
|
85 |
+
"""
|
86 |
+
Perform object detection prediction using the model.
|
87 |
+
|
88 |
+
Parameters:
|
89 |
+
- model (torch.nn.Module): The object detection model.
|
90 |
+
- image (torch.Tensor): The input image.
|
91 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
92 |
+
- iou_threshold (float): IoU threshold for non-maximum suppression.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
- numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', and 'labels'.
|
96 |
+
"""
|
97 |
model.eval()
|
98 |
with torch.no_grad():
|
99 |
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
|
|
110 |
|
111 |
selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold)
|
112 |
|
113 |
+
# Find orientation of the task by checking the size of all the boxes and delete the ones that are not in the same orientation
|
114 |
vertical = 0
|
115 |
for i in range(len(labels)):
|
116 |
if labels[i] != list(object_dict.values()).index('task'):
|
|
|
124 |
|
125 |
if vertical < horizontal:
|
126 |
if is_vertical(boxes[i]):
|
127 |
+
# Find the element in the list and remove it
|
128 |
if i in selected_boxes:
|
129 |
selected_boxes.remove(i)
|
130 |
elif vertical > horizontal:
|
131 |
if is_vertical(boxes[i]) == False:
|
132 |
+
# Find the element in the list and remove it
|
133 |
if i in selected_boxes:
|
134 |
selected_boxes.remove(i)
|
135 |
else:
|
|
|
139 |
scores = scores[selected_boxes]
|
140 |
labels = labels[selected_boxes]
|
141 |
|
142 |
+
# Find the outlier objects that are too small by the area
|
143 |
+
obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower")
|
144 |
+
obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref=['task'], mode="upper")
|
145 |
|
146 |
selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
|
147 |
|
|
|
|
|
148 |
boxes = boxes[selected_object]
|
149 |
scores = scores[selected_object]
|
150 |
labels = labels[selected_object]
|
151 |
|
152 |
+
# Modify the label of the sub-process to task
|
153 |
for i in range(len(labels)):
|
154 |
if labels[i] == list(object_dict.values()).index('subProcess'):
|
155 |
labels[i] = list(object_dict.values()).index('task')
|
156 |
+
# Delete all lane and also the value in the labels and scores
|
157 |
lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
|
158 |
boxes = np.delete(boxes, lane_index, axis=0)
|
159 |
labels = np.delete(labels, lane_index)
|
|
|
172 |
|
173 |
|
174 |
def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15):
|
175 |
+
"""
|
176 |
+
Perform arrow detection prediction using the model.
|
177 |
+
|
178 |
+
Parameters:
|
179 |
+
- model (torch.nn.Module): The arrow detection model.
|
180 |
+
- image (torch.Tensor): The input image.
|
181 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
182 |
+
- iou_threshold (float): IoU threshold for non-maximum suppression.
|
183 |
+
- distance_treshold (int): Distance threshold for keypoint correction.
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
- numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', 'labels', and 'keypoints'.
|
187 |
+
"""
|
188 |
model.eval()
|
189 |
with torch.no_grad():
|
190 |
image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
|
|
|
221 |
|
222 |
return image, prediction
|
223 |
|
224 |
+
|
225 |
def mix_predictions(objects_pred, arrow_pred):
|
226 |
+
"""
|
227 |
+
Combine object and arrow predictions into a single set of predictions.
|
228 |
+
|
229 |
+
Parameters:
|
230 |
+
- objects_pred (dict): Object predictions dictionary.
|
231 |
+
- arrow_pred (dict): Arrow predictions dictionary.
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
- tuple: Combined boxes, labels, scores, and keypoints.
|
235 |
+
"""
|
236 |
# Initialize the list of lists for keypoints
|
237 |
object_keypoints = []
|
238 |
|
|
|
245 |
keypoints = [[0, 0, 0], [0, 0, 0]]
|
246 |
object_keypoints.append(keypoints)
|
247 |
|
248 |
+
# Concatenate the two predictions
|
249 |
if len(arrow_pred['boxes']) == 0:
|
250 |
return objects_pred['boxes'], objects_pred['labels'], objects_pred['scores'], object_keypoints
|
251 |
|
|
|
258 |
|
259 |
|
260 |
def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
|
261 |
+
"""
|
262 |
+
Regroup elements by pool based on IoU and proximity.
|
263 |
+
|
264 |
+
Parameters:
|
265 |
+
- boxes (array): Array of bounding boxes.
|
266 |
+
- labels (array): Array of labels for each bounding box.
|
267 |
+
- scores (array): Array of confidence scores for each bounding box.
|
268 |
+
- keypoints (array): Array of keypoints.
|
269 |
+
- class_dict (dict): Dictionary mapping class names to indices.
|
270 |
+
- iou_threshold (float): IoU threshold for grouping.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
- dict: Dictionary grouping elements by pool.
|
274 |
+
- array: Updated arrays of boxes, labels, scores, and keypoints.
|
275 |
+
"""
|
276 |
pool_dict = {}
|
277 |
|
278 |
# Filter out pools with IoU greater than the threshold
|
|
|
339 |
return pool_dict, boxes, labels, scores, keypoints
|
340 |
|
341 |
|
|
|
342 |
def create_links(keypoints, boxes, labels, class_dict):
|
343 |
+
"""
|
344 |
+
Create links between elements based on keypoints.
|
345 |
+
|
346 |
+
Parameters:
|
347 |
+
- keypoints (array): Array of keypoints.
|
348 |
+
- boxes (array): Array of bounding boxes.
|
349 |
+
- labels (array): Array of labels for each bounding box.
|
350 |
+
- class_dict (dict): Dictionary mapping class names to indices.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
- list: List of links between elements.
|
354 |
+
- list: List of best points for each link.
|
355 |
+
"""
|
356 |
best_points = []
|
357 |
links = []
|
358 |
for i in range(len(labels)):
|
359 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow') or labels[i] == list(class_dict.values()).index('messageFlow'):
|
360 |
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
361 |
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
362 |
|
|
|
364 |
best_points.append([point_start, point_end])
|
365 |
links.append([closest1, closest2])
|
366 |
else:
|
367 |
+
best_points.append([None, None])
|
368 |
+
links.append([None, None])
|
369 |
|
370 |
for i in range(len(labels)):
|
371 |
+
if labels[i] == list(class_dict.values()).index('dataAssociation'):
|
372 |
closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
|
373 |
closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
|
374 |
if closest1 is not None and closest2 is not None:
|
|
|
377 |
|
378 |
return links, best_points
|
379 |
|
380 |
+
|
381 |
def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
|
382 |
+
"""
|
383 |
+
Correct labels based on the relationships between elements and pools.
|
384 |
+
|
385 |
+
Parameters:
|
386 |
+
- boxes (array): Array of bounding boxes.
|
387 |
+
- labels (array): Array of labels for each bounding box.
|
388 |
+
- class_dict (dict): Dictionary mapping class names to indices.
|
389 |
+
- pool_dict (dict): Dictionary grouping elements by pool.
|
390 |
+
- flow_links (list): List of links between elements.
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
- array: Corrected labels.
|
394 |
+
- list: Updated flow links.
|
395 |
+
"""
|
396 |
sequence_flow_index = list(class_dict.values()).index('sequenceFlow')
|
397 |
message_flow_index = list(class_dict.values()).index('messageFlow')
|
398 |
data_association_index = list(class_dict.values()).index('dataAssociation')
|
|
|
440 |
return labels, flow_links
|
441 |
|
442 |
|
443 |
+
def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower"):
|
444 |
+
"""
|
445 |
+
Identify outlier objects based on their area.
|
446 |
+
|
447 |
+
Parameters:
|
448 |
+
- boxes (array): Array of bounding boxes.
|
449 |
+
- labels (array): Array of labels for each bounding box.
|
450 |
+
- class_dict (dict): Dictionary mapping class names to indices.
|
451 |
+
- std_factor (float): Standard deviation factor for determining outliers.
|
452 |
+
- element_ref (list): List of reference elements for calculating area statistics.
|
453 |
+
- mode (str): Mode to identify outliers ('lower', 'upper', or 'both').
|
454 |
+
|
455 |
+
Returns:
|
456 |
+
- list: Indices of kept objects that are not outliers.
|
457 |
+
"""
|
458 |
# Filter out the sizes of events, data objects, and message events
|
459 |
event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
|
460 |
event_boxes = [boxes[i] for i in event_indices]
|
|
|
475 |
kept_indices = []
|
476 |
|
477 |
if mode == "lower" or mode == 'both':
|
478 |
+
# Check for objects that could be too small
|
479 |
for idx, (box, label) in enumerate(zip(boxes, labels)):
|
480 |
area = (box[2] - box[0]) * (box[3] - box[1])
|
481 |
if not (area_lower_threshold <= area):
|
|
|
485 |
kept_indices.append(idx)
|
486 |
|
487 |
if mode == "upper" or mode == 'both':
|
488 |
+
# Check for objects that could be too big
|
489 |
for idx, (box, label) in enumerate(zip(boxes, labels)):
|
490 |
if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
|
491 |
kept_indices.append(idx)
|
|
|
497 |
else:
|
498 |
kept_indices.append(idx)
|
499 |
|
|
|
500 |
return kept_indices
|
501 |
|
502 |
|
|
|
503 |
def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
|
504 |
+
"""
|
505 |
+
Perform final corrections on the predictions by deleting irrelevant or small pools and duplicate elements.
|
506 |
+
|
507 |
+
Parameters:
|
508 |
+
- boxes (array): Array of bounding boxes.
|
509 |
+
- labels (array): Array of labels for each bounding box.
|
510 |
+
- scores (array): Array of confidence scores for each bounding box.
|
511 |
+
- keypoints (array): Array of keypoints.
|
512 |
+
- bpmn_id (list): List of BPMN IDs.
|
513 |
+
- links (list): List of links between elements.
|
514 |
+
- best_points (list): List of best points for each link.
|
515 |
+
- pool_dict (dict): Dictionary grouping elements by pool.
|
516 |
+
- limit_area (int): Minimum area threshold for pools.
|
517 |
+
|
518 |
+
Returns:
|
519 |
+
- tuple: Corrected arrays of boxes, labels, scores, keypoints, BPMN IDs, links, best points, and pool dictionary.
|
520 |
+
"""
|
521 |
+
# Delete pools that have only messageFlow on it
|
522 |
delete_pool = []
|
523 |
for pool_index, elements in pool_dict.items():
|
524 |
+
# Find the position of the pool_index in the bpmn_id
|
525 |
if pool_index in bpmn_id:
|
526 |
position = bpmn_id.index(pool_index)
|
527 |
else:
|
|
|
534 |
delete_pool.append(position)
|
535 |
print(f"Pool {pool_index} contains only arrow elements, deleting it")
|
536 |
|
537 |
+
# Calculate the area of the pool
|
538 |
if position < len(boxes):
|
539 |
pool = boxes[position]
|
540 |
area = (pool[2] - pool[0]) * (pool[3] - pool[1])
|
541 |
+
if len(pool_dict) > 1 and area < limit_area:
|
542 |
delete_pool.append(position)
|
543 |
print(f"Pool {pool_index} is too small, deleting it")
|
544 |
|
|
|
546 |
delete_pool.append(position)
|
547 |
print(f"Pool {position} is vertical, deleting it")
|
548 |
|
|
|
549 |
delete_elements = []
|
550 |
# Check if there is an arrow that has the same links
|
551 |
for i in range(len(labels)):
|
552 |
+
for j in range(i + 1, len(labels)):
|
553 |
if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'):
|
554 |
if links[i] == links[j]:
|
555 |
+
print(f'Element {i} and {j} have the same links')
|
556 |
if scores[i] > scores[j]:
|
557 |
+
print('Delete element', j)
|
558 |
delete_elements.append(j)
|
559 |
else:
|
560 |
+
print('Delete element', i)
|
561 |
delete_elements.append(i)
|
562 |
|
563 |
+
# Concatenate the delete_elements and the delete_pool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
delete_elements = delete_elements + delete_pool
|
565 |
+
# Delete double value in delete_elements
|
566 |
delete_elements = list(set(delete_elements))
|
567 |
|
568 |
boxes = np.delete(boxes, delete_elements, axis=0)
|
|
|
574 |
best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
|
575 |
|
576 |
for i in range(len(delete_pool)):
|
577 |
+
# Find the bpmn_id of the pool
|
578 |
pool_index = bpmn_id[delete_pool[i]]
|
579 |
+
# Delete the pool_index in pool_dict
|
580 |
del pool_dict[pool_index]
|
581 |
|
582 |
bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
|
583 |
|
584 |
+
# Also delete the element in the pool_dict
|
585 |
for pool_index, elements in pool_dict.items():
|
586 |
pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
|
587 |
|
588 |
return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
|
589 |
|
590 |
+
|
591 |
def give_link_to_element(links, labels):
|
592 |
+
"""
|
593 |
+
Assign links to elements to create BPMN IDs for events.
|
594 |
+
|
595 |
+
Parameters:
|
596 |
+
- links (list): List of links between elements.
|
597 |
+
- labels (array): Array of labels for each bounding box.
|
598 |
+
|
599 |
+
Returns:
|
600 |
+
- list: Updated list of links with assigned links for events.
|
601 |
+
"""
|
602 |
+
# Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event
|
603 |
+
for i in range(len(links)):
|
604 |
+
if labels[i] == list(class_dict.values()).index('sequenceFlow'):
|
605 |
+
id1, id2 = links[i]
|
606 |
+
if (id1 and id2) is not None:
|
607 |
+
links[id1][1] = i
|
608 |
+
links[id2][0] = i
|
609 |
+
return links
|
610 |
|
611 |
|
612 |
def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict):
|
613 |
+
"""
|
614 |
+
Generate a data dictionary containing image and prediction information.
|
615 |
+
|
616 |
+
Parameters:
|
617 |
+
- image (numpy.array): The input image.
|
618 |
+
- boxes (array): Array of bounding boxes.
|
619 |
+
- labels (array): Array of labels for each bounding box.
|
620 |
+
- scores (array): Array of confidence scores for each bounding box.
|
621 |
+
- keypoints (array): Array of keypoints.
|
622 |
+
- bpmn_id (list): List of BPMN IDs.
|
623 |
+
- flow_links (list): List of links between elements.
|
624 |
+
- best_points (list): List of best points for each link.
|
625 |
+
- pool_dict (dict): Dictionary grouping elements by pool.
|
626 |
+
|
627 |
+
Returns:
|
628 |
+
- dict: Data dictionary containing all prediction information.
|
629 |
+
"""
|
630 |
+
idx = []
|
631 |
+
for i in range(len(labels)):
|
632 |
+
idx.append(i)
|
633 |
+
|
634 |
+
data = {
|
635 |
+
'image': image,
|
636 |
+
'idx': idx,
|
637 |
+
'boxes': boxes,
|
638 |
+
'labels': labels,
|
639 |
+
'scores': scores,
|
640 |
+
'keypoints': keypoints,
|
641 |
+
'links': flow_links,
|
642 |
+
'best_points': best_points,
|
643 |
+
'pool_dict': pool_dict,
|
644 |
+
'BPMN_id': bpmn_id,
|
645 |
+
}
|
646 |
+
|
647 |
+
return data
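A sketch of how the returned dictionary can be consumed downstream, assuming the arrays produced earlier in the pipeline are in scope:

data = generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
for i in data['idx']:
    print(data['BPMN_id'][i], data['labels'][i], data['links'][i])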
|
648 |
|
|
|
649 |
|
650 |
+
def develop_prediction(boxes, labels, scores, keypoints, class_dict):
|
651 |
+
"""
|
652 |
+
Develop predictions by regrouping elements, creating links, and correcting labels.
|
653 |
+
|
654 |
+
Parameters:
|
655 |
+
- boxes (array): Array of bounding boxes.
|
656 |
+
- labels (array): Array of labels for each bounding box.
|
657 |
+
- scores (array): Array of confidence scores for each bounding box.
|
658 |
+
- keypoints (array): Array of keypoints.
|
659 |
+
- class_dict (dict): Dictionary mapping class names to indices.
|
660 |
+
|
661 |
+
Returns:
|
662 |
+
- tuple: Developed prediction components including boxes, labels, scores, keypoints, BPMN IDs, flow links, best points, and pool dictionary.
|
663 |
+
"""
|
664 |
pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
|
665 |
|
666 |
+
bpmn_id, pool_dict = create_BPMN_id(labels, pool_dict)
|
667 |
|
668 |
# Create links between elements
|
669 |
flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
|
670 |
|
671 |
+
# Correct the labels of some sequenceFlow that cross multiple pools
|
672 |
labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
|
673 |
|
674 |
+
# Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event
|
675 |
flow_links = give_link_to_element(flow_links, labels)
|
676 |
|
677 |
+
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = last_correction(
|
678 |
+
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
|
679 |
+
)
|
680 |
|
681 |
+
return boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict
|
682 |
|
|
|
683 |
|
684 |
def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
|
685 |
+
"""
|
686 |
+
Perform a full prediction by combining object and arrow models and generating data.
|
687 |
+
|
688 |
+
Parameters:
|
689 |
+
- model_object (torch.nn.Module): The object detection model.
|
690 |
+
- model_arrow (torch.nn.Module): The arrow detection model.
|
691 |
+
- image (torch.Tensor): The input image.
|
692 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
693 |
+
- iou_threshold (float): IoU threshold for non-maximum suppression.
|
694 |
+
- resize (bool): Flag indicating whether to resize the image.
|
695 |
+
- distance_treshold (int): Distance threshold for keypoint correction.
|
696 |
+
|
697 |
+
Returns:
|
698 |
+
- numpy.array, dict: The processed image and the data dictionary containing prediction information.
|
699 |
+
"""
|
700 |
model_object.eval() # Set the model to evaluation mode
|
701 |
model_arrow.eval() # Set the model to evaluation mode
|
702 |
|
|
|
709 |
|
710 |
boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
|
711 |
|
712 |
+
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(
|
713 |
+
boxes, labels, scores, keypoints, class_dict
|
714 |
+
)
|
715 |
|
716 |
image = image.permute(1, 2, 0).cpu().numpy()
|
717 |
image = (image * 255).astype(np.uint8)
|
|
|
720 |
|
721 |
return image, data
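A usage sketch mirroring the way the Streamlit front end calls this entry point; the file name and the already-loaded models are placeholders:

from PIL import Image
from torchvision.transforms import functional as F

img_tensor = F.to_tensor(Image.open('diagram.png').convert('RGB'))  # 'diagram.png' is a placeholder
annotated, data = full_prediction(model_object, model_arrow, img_tensor,
                                  score_threshold=0.5, iou_threshold=0.5,
                                  distance_treshold=30)
print(len(data['boxes']), 'elements detected')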
|
722 |
|
723 |
+
|
724 |
def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5):
|
725 |
+
"""
|
726 |
+
Evaluate the model's performance on a per-class basis.
|
727 |
+
|
728 |
+
Parameters:
|
729 |
+
- pred_boxes (array): Predicted bounding boxes.
|
730 |
+
- true_boxes (array): Ground truth bounding boxes.
|
731 |
+
- pred_labels (array): Predicted labels.
|
732 |
+
- true_labels (array): Ground truth labels.
|
733 |
+
- model_dict (dict): Dictionary mapping model labels to indices.
|
734 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
735 |
+
|
736 |
+
Returns:
|
737 |
+
- tuple: Precision, recall, and F1-score per class.
|
738 |
+
"""
|
739 |
# Initialize dictionaries to hold per-class counts
|
740 |
class_tp = {cls: 0 for cls in model_dict.values()}
|
741 |
class_fp = {cls: 0 for cls in model_dict.values()}
|
|
|
779 |
return class_precision, class_recall, class_f1_score
|
780 |
|
781 |
|
782 |
+
def keypoints_measure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5):
|
783 |
+
"""
|
784 |
+
Measure the accuracy of predicted keypoints compared to true keypoints.
|
785 |
+
|
786 |
+
Parameters:
|
787 |
+
- pred_boxes (array): Predicted bounding boxes.
|
788 |
+
- pred_box (array): Single predicted bounding box.
|
789 |
+
- true_boxes (array): Ground truth bounding boxes.
|
790 |
+
- true_box (array): Single ground truth bounding box.
|
791 |
+
- pred_keypoints (array): Predicted keypoints.
|
792 |
+
- true_keypoints (array): Ground truth keypoints.
|
793 |
+
- distance_threshold (int): Distance threshold for considering a keypoint match.
|
794 |
+
|
795 |
+
Returns:
|
796 |
+
- tuple: Number of correct keypoints and whether the keypoints are reverted.
|
797 |
+
"""
|
798 |
result = 0
|
799 |
reverted = False
|
800 |
+
# Find the position of keypoints in the list
|
801 |
idx = np.where(pred_boxes == pred_box)[0][0]
|
802 |
idx2 = np.where(true_boxes == true_box)[0][0]
|
803 |
|
|
|
820 |
|
821 |
return result, reverted
|
822 |
|
823 |
+
|
824 |
def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5):
|
825 |
+
"""
|
826 |
+
Evaluate a single image's predictions against the ground truth.
|
827 |
+
|
828 |
+
Parameters:
|
829 |
+
- pred_boxes (array): Predicted bounding boxes.
|
830 |
+
- true_boxes (array): Ground truth bounding boxes.
|
831 |
+
- pred_labels (array): Predicted labels.
|
832 |
+
- true_labels (array): Ground truth labels.
|
833 |
+
- pred_keypoints (array): Predicted keypoints.
|
834 |
+
- true_keypoints (array): Ground truth keypoints.
|
835 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
836 |
+
- distance_threshold (int): Distance threshold for considering a keypoint match.
|
837 |
+
|
838 |
+
Returns:
|
839 |
+
- tuple: True positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count.
|
840 |
+
"""
|
841 |
tp, fp, fn = 0, 0, 0
|
842 |
key_t, key_f = 0, 0
|
843 |
labels_t, labels_f = 0, 0
|
|
|
852 |
iou_val = iou(pred_box, true_box)
|
853 |
if iou_val >= iou_threshold:
|
854 |
if true_keypoints is not None and pred_keypoints is not None:
|
855 |
+
key_result, reverted = keypoints_measure(
|
856 |
+
pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold
|
857 |
+
)
|
858 |
key_t += key_result
|
859 |
key_f += 2 - key_result
|
860 |
if reverted:
|
|
|
877 |
|
878 |
|
879 |
def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
|
880 |
+
"""
|
881 |
+
Evaluate the model on a dataset using predictions for evaluation.
|
882 |
+
|
883 |
+
Parameters:
|
884 |
+
- model (torch.nn.Module): The model to evaluate.
|
885 |
+
- loader (torch.utils.data.DataLoader): DataLoader for the dataset.
|
886 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
887 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
888 |
+
- distance_threshold (int): Distance threshold for considering a keypoint match.
|
889 |
+
- key_correction (bool): Whether to apply keypoint correction.
|
890 |
+
- model_type (str): Type of model ('object' or 'arrow').
|
891 |
+
|
892 |
+
Returns:
|
893 |
+
- tuple: Evaluation results including true positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count.
|
894 |
+
"""
|
895 |
model.eval()
|
896 |
tp, fp, fn = 0, 0, 0
|
897 |
labels_t, labels_f = 0, 0
|
|
|
929 |
filtered_labels = []
|
930 |
filtered_keypoints = []
|
931 |
if 'keypoints' not in prediction:
|
932 |
+
# Create a list of zeros of length equal to the number of boxes
|
933 |
pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))]
|
934 |
|
935 |
for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints):
|
|
|
946 |
filtered_keypoints = None
|
947 |
true_keypoints = None
|
948 |
tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image(
|
949 |
+
filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold
|
950 |
+
)
|
951 |
|
952 |
tp += tp_img
|
953 |
fp += fp_img
|
|
|
960 |
|
961 |
return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted
|
962 |
|
|
|
963 |
|
964 |
+
def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'):
|
965 |
+
"""
|
966 |
+
Main function to evaluate the model on the test dataset.
|
967 |
+
|
968 |
+
Parameters:
|
969 |
+
- model (torch.nn.Module): The model to evaluate.
|
970 |
+
- test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
|
971 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
972 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
973 |
+
- distance_threshold (int): Distance threshold for considering a keypoint match.
|
974 |
+
- key_correction (bool): Whether to apply keypoint correction.
|
975 |
+
- model_type (str): Type of model ('object' or 'arrow').
|
976 |
+
|
977 |
+
Returns:
|
978 |
+
- tuple: Precision, recall, F1-score, key accuracy, and reverted accuracy.
|
979 |
+
"""
|
980 |
+
tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation(
|
981 |
+
model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type
|
982 |
+
)
|
983 |
|
984 |
labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0
|
985 |
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
|
|
995 |
return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy
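A quick numeric check of the formulas above, using hypothetical counts and the standard F1 definition:

tp, fp, fn = 80, 20, 10
precision = tp / (tp + fp)                                 # 0.80
recall = tp / (tp + fn)                                    # ~0.889
f1_score = 2 * precision * recall / (precision + recall)   # ~0.842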
|
996 |
|
997 |
|
|
|
998 |
def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5):
|
999 |
+
"""
|
1000 |
+
Evaluate a single image's predictions on a per-class basis.
|
1001 |
+
|
1002 |
+
Parameters:
|
1003 |
+
- pred_boxes (array): Predicted bounding boxes.
|
1004 |
+
- true_boxes (array): Ground truth bounding boxes.
|
1005 |
+
- pred_labels (array): Predicted labels.
|
1006 |
+
- true_labels (array): Ground truth labels.
|
1007 |
+
- class_tp (dict): Dictionary of true positive counts per class.
|
1008 |
+
- class_fp (dict): Dictionary of false positive counts per class.
|
1009 |
+
- class_fn (dict): Dictionary of false negative counts per class.
|
1010 |
+
- model_dict (dict): Dictionary mapping model labels to indices.
|
1011 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
1012 |
+
"""
|
1013 |
matched_true_boxes = set()
|
1014 |
for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):
|
1015 |
match_found = False
|
|
|
1028 |
if idx not in matched_true_boxes:
|
1029 |
class_fn[model_dict[true_label]] += 1
|
1030 |
|
1031 |
+
|
1032 |
def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5):
|
1033 |
+
"""
|
1034 |
+
Generate predictions for evaluation on a per-class basis.
|
1035 |
+
|
1036 |
+
Parameters:
|
1037 |
+
- model (torch.nn.Module): The model to evaluate.
|
1038 |
+
- loader (torch.utils.data.DataLoader): DataLoader for the dataset.
|
1039 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
1040 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
1041 |
+
|
1042 |
+
Yields:
|
1043 |
+
- tuple: Predicted and true boxes and labels for each batch.
|
1044 |
+
"""
|
1045 |
model.eval()
|
1046 |
with torch.no_grad():
|
1047 |
for images, targets_im in tqdm(loader, desc="Testing... "):
|
|
|
1071 |
|
1072 |
yield pred_boxes, true_boxes, pred_labels, true_labels
|
1073 |
|
1074 |
+
|
1075 |
def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5):
|
1076 |
+
"""
|
1077 |
+
Evaluate the model's performance on a per-class basis for the entire dataset.
|
1078 |
+
|
1079 |
+
Parameters:
|
1080 |
+
- model (torch.nn.Module): The model to evaluate.
|
1081 |
+
- test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
|
1082 |
+
- model_dict (dict): Dictionary mapping model labels to indices.
|
1083 |
+
- score_threshold (float): Score threshold for filtering predictions.
|
1084 |
+
- iou_threshold (float): IoU threshold for determining matches.
|
1085 |
+
|
1086 |
+
Returns:
|
1087 |
+
- tuple: Precision, recall, and F1-score per class.
|
1088 |
+
"""
|
1089 |
class_tp = {cls: 0 for cls in model_dict.values()}
|
1090 |
class_fp = {cls: 0 for cls in model_dict.values()}
|
1091 |
class_fn = {cls: 0 for cls in model_dict.values()}
|
|
|
1106 |
class_recall[cls] = recall
|
1107 |
class_f1_score[cls] = f1_score
|
1108 |
|
1109 |
+
return class_precision, class_recall, class_f1_score
|
modules/streamlit_utils.py
CHANGED
@@ -15,46 +15,64 @@ from modules.display import draw_stream
|
|
15 |
from modules.eval import full_prediction
|
16 |
from modules.train import get_faster_rcnn_model, get_arrow_model
|
17 |
from streamlit_image_comparison import image_comparison
|
18 |
-
|
19 |
from streamlit_image_annotation import detection
|
20 |
from modules.toXML import create_XML
|
21 |
from modules.eval import develop_prediction, generate_data
|
22 |
from modules.utils import class_dict, object_dict
|
23 |
-
|
24 |
from modules.htlm_webpage import display_bpmn_xml
|
25 |
from streamlit_cropper import st_cropper
|
26 |
from streamlit_image_select import image_select
|
27 |
from streamlit_js_eval import streamlit_js_eval
|
28 |
-
|
29 |
from modules.toWizard import create_wizard_file
|
30 |
from huggingface_hub import hf_hub_download
|
31 |
import time
|
32 |
-
|
33 |
from modules.toXML import get_size_elements
|
34 |
|
35 |
-
|
36 |
def get_memory_usage():
|
|
|
|
|
|
|
37 |
process = psutil.Process()
|
38 |
mem_info = process.memory_info()
|
39 |
return mem_info.rss / (1024 ** 2) # Return memory usage in MB
|
40 |
|
|
|
41 |
def clear_memory():
|
|
|
|
|
|
|
42 |
st.session_state.clear()
|
43 |
gc.collect()
|
44 |
|
45 |
-
|
46 |
# Function to read XML content from a file
|
47 |
def read_xml_file(filepath):
|
48 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
with open(filepath, 'r', encoding='utf-8') as file:
|
50 |
return file.read()
|
51 |
|
52 |
-
|
53 |
# Suppress the symlink warning
|
54 |
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
|
55 |
|
56 |
# Function to load the models only once and use session state to keep track of it
|
57 |
def load_models():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
with st.spinner('Loading model...'):
|
59 |
model_object = get_faster_rcnn_model(len(object_dict))
|
60 |
model_arrow = get_arrow_model(len(arrow_dict), 2)
|
@@ -71,7 +89,6 @@ def load_models():
|
|
71 |
|
72 |
# Load model arrow
|
73 |
if not Path(output_arrow).exists():
|
74 |
-
# Download model from Hugging Face Hub
|
75 |
model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
|
76 |
st.session_state.model_arrow = model_arrow
|
77 |
print('Model arrow downloaded from Hugging Face Hub')
|
@@ -82,22 +99,18 @@ def load_models():
|
|
82 |
print()
|
83 |
st.session_state.model_arrow = model_arrow
|
84 |
print('Model arrow loaded from local file')
|
85 |
-
|
86 |
|
87 |
# Load model object
|
88 |
if not Path(output_object).exists():
|
89 |
-
# Download model from Hugging Face Hub
|
90 |
model_object.load_state_dict(torch.load(model_object_path, map_location=device))
|
91 |
st.session_state.model_object = model_object
|
92 |
print('Model object downloaded from Hugging Face Hub')
|
93 |
-
# Save the model locally
|
94 |
torch.save(model_object.state_dict(), output_object)
|
95 |
elif 'model_object' not in st.session_state and Path(output_object).exists():
|
96 |
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
97 |
print()
|
98 |
st.session_state.model_object = model_object
|
99 |
-
print('Model object loaded from local file')
|
100 |
-
|
101 |
|
102 |
# Move models to device
|
103 |
model_arrow.to(device)
|
@@ -110,6 +123,17 @@ def load_models():
|
|
110 |
|
111 |
# Function to prepare the image for processing
|
112 |
def prepare_image(image, pad=True, new_size=(1333, 1333)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
original_size = image.size
|
114 |
# Calculate scale to fit the new size while maintaining aspect ratio
|
115 |
scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
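# Example: for a 2000x1000 upload and new_size=(1333, 1333),
# scale = min(1333/2000, 1333/1000) ~ 0.667, so the image becomes roughly 1333x667
# before any optional padding back to 1333x1333.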
|
@@ -128,6 +152,15 @@ def prepare_image(image, pad=True, new_size=(1333, 1333)):
|
|
128 |
|
129 |
# Function to display various options for image annotation
|
130 |
def display_options(image, score_threshold, is_mobile, screen_width):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
col1, col2, col3, col4, col5 = st.columns(5)
|
132 |
with col1:
|
133 |
write_class = st.toggle("Write Class", value=True)
|
@@ -157,7 +190,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
|
|
157 |
if is_mobile is True:
|
158 |
width = screen_width
|
159 |
else:
|
160 |
-
width = screen_width//2
|
161 |
|
162 |
# Display the original and annotated images side by side
|
163 |
image_comparison(
|
@@ -171,8 +204,25 @@ def display_options(image, score_threshold, is_mobile, screen_width):
|
|
171 |
|
172 |
# Function to perform inference on the uploaded image using the loaded models
|
173 |
def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
uploaded_image = prepare_image(image, pad=False)
|
175 |
-
|
176 |
img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))
|
177 |
|
178 |
# Display original image
|
@@ -181,7 +231,7 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
|
|
181 |
if is_mobile is False:
|
182 |
width = screen_width
|
183 |
if is_mobile is False:
|
184 |
-
width = screen_width//2
|
185 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
186 |
|
187 |
# Perform OCR on the uploaded image
|
@@ -193,9 +243,9 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
|
|
193 |
# Prediction
|
194 |
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
195 |
|
196 |
-
#Mapping text to prediction
|
197 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
198 |
-
|
199 |
# Remove the original image display
|
200 |
image_placeholder.empty()
|
201 |
|
@@ -204,24 +254,44 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
|
|
204 |
|
205 |
return image, st.session_state.prediction, st.session_state.text_mapping
|
206 |
|
|
|
207 |
@st.cache_data
|
208 |
def get_image(uploaded_file):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
return Image.open(uploaded_file).convert('RGB')
|
210 |
|
211 |
-
|
212 |
def configure_page():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
st.set_page_config(layout="wide")
|
214 |
screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
|
215 |
is_mobile = screen_width is not None and screen_width < 800
|
216 |
return is_mobile, screen_width
|
217 |
|
|
|
218 |
def display_banner(is_mobile):
|
219 |
-
# JavaScript expression to detect dark mode
|
220 |
-
dark_mode_js = """
|
221 |
-
(window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches)
|
222 |
"""
|
|
|
223 |
|
224 |
-
|
|
|
|
|
|
|
225 |
is_dark_mode = streamlit_js_eval(js_expressions=dark_mode_js, key='dark_mode')
|
226 |
|
227 |
if is_mobile:
|
@@ -235,16 +305,27 @@ def display_banner(is_mobile):
|
|
235 |
else:
|
236 |
st.image("./images/banner_desktop.png", use_column_width=True)
|
237 |
|
|
|
238 |
def display_title(is_mobile):
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
title = "Welcome on the BPMN AI model recognition app"
|
240 |
if is_mobile:
|
241 |
title = "Welcome on the mobile version of BPMN AI model recognition app"
|
242 |
st.title(title)
|
243 |
|
|
|
244 |
def display_sidebar():
|
|
|
|
|
|
|
245 |
st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
|
246 |
st.sidebar.subheader("Instructions:")
|
247 |
-
st.sidebar.text("1. Upload
|
248 |
st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
|
249 |
st.sidebar.text("3. Set the score threshold for\n prediction (default is 0.5)")
|
250 |
st.sidebar.text("4. Click on 'Launch Prediction'")
|
@@ -252,20 +333,20 @@ def display_sidebar():
|
|
252 |
st.sidebar.text("6. You can modify the result \n by clicking on:\n 'Method&Style modification'")
|
253 |
st.sidebar.text("7. You can change the scale for \n the XML file and the size of \n elements (default is 1.0)")
|
254 |
st.sidebar.text("8. You can modify with modeler \n and download the result in \n right format")
|
255 |
-
|
256 |
st.sidebar.subheader("If there is an error, try to:")
|
257 |
st.sidebar.text("1. Change the score threshold")
|
258 |
st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the\n center of the image")
|
259 |
st.sidebar.text("3. Re-Launch the prediction")
|
260 |
-
|
261 |
st.sidebar.subheader("You can close this sidebar")
|
262 |
-
|
263 |
for i in range(5):
|
264 |
st.sidebar.subheader("")
|
265 |
-
|
266 |
st.sidebar.subheader("Made with ❤️ by Benjamin.K")
|
267 |
|
|
|
268 |
def initialize_session_state():
|
|
|
|
|
|
|
269 |
if 'pool_bboxes' not in st.session_state:
|
270 |
st.session_state.pool_bboxes = []
|
271 |
if 'model_loaded' not in st.session_state:
|
@@ -275,7 +356,14 @@ def initialize_session_state():
|
|
275 |
load_models()
|
276 |
st.rerun()
|
277 |
|
|
|
278 |
def load_example_image():
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
with st.expander("Use example images"):
|
280 |
img_selected = image_select(
|
281 |
"If you have no image and just want to test the demo, click on one of these images",
|
@@ -287,10 +375,20 @@ def load_example_image():
|
|
287 |
)
|
288 |
return img_selected
|
289 |
|
|
|
290 |
def load_user_image(img_selected, is_mobile):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
if img_selected == './images/none.jpg':
|
292 |
img_selected = None
|
293 |
-
|
294 |
if img_selected is not None:
|
295 |
uploaded_file = img_selected
|
296 |
else:
|
@@ -300,13 +398,23 @@ def load_user_image(img_selected, is_mobile):
|
|
300 |
col1, col2 = st.columns(2)
|
301 |
with col1:
|
302 |
uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
|
303 |
-
|
304 |
return uploaded_file
|
305 |
|
|
|
306 |
def display_image(uploaded_file, screen_width, is_mobile):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
if 'rotation_angle' not in st.session_state:
|
308 |
st.session_state.rotation_angle = 0 # Initialize the rotation angle in session state
|
309 |
-
|
310 |
if 'brightness' not in st.session_state:
|
311 |
st.session_state.brightness = 1.0 # Initialize brightness in session state
|
312 |
|
@@ -349,15 +457,23 @@ def display_image(uploaded_file, screen_width, is_mobile):
|
|
349 |
if not is_mobile:
|
350 |
cropped_image = crop_image(adjusted_image, original_image)
|
351 |
else:
|
352 |
-
st.image(adjusted_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
|
353 |
cropped_image = original_image
|
354 |
|
355 |
return cropped_image
|
356 |
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
def crop_image(resized_image, original_image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
marge = 10
|
362 |
cropped_box = st_cropper(
|
363 |
resized_image,
|
@@ -373,23 +489,50 @@ def crop_image(resized_image, original_image):
|
|
373 |
cropped_image = original_image.crop((x0, y0, x1, y1))
|
374 |
return cropped_image
|
375 |
|
|
|
376 |
def get_score_threshold(is_mobile):
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
col1, col2 = st.columns(2)
|
378 |
with col1:
|
379 |
-
st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
|
380 |
|
381 |
def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
st.session_state.crop_image = cropped_image
|
383 |
with st.spinner('Processing...'):
|
384 |
-
image, _
|
385 |
st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
|
386 |
score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
|
387 |
)
|
388 |
st.balloons()
|
389 |
return image
|
390 |
-
|
391 |
|
392 |
def modify_results(percentage_text_dist_thresh=0.5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
with st.expander("Method & Style modification"):
|
394 |
label_list = list(object_dict.values())
|
395 |
if st.session_state.prediction['labels'][-1] == 6:
|
@@ -445,7 +588,6 @@ def modify_results(percentage_text_dist_thresh=0.5):
|
|
445 |
|
446 |
object_labels = np.array(object_labels)
|
447 |
|
448 |
-
|
449 |
if len(object_bboxes) == len(bboxes):
|
450 |
# Calculate absolute differences
|
451 |
abs_diff = np.abs(object_bboxes - bboxes)
|
@@ -456,7 +598,7 @@ def modify_results(percentage_text_dist_thresh=0.5):
|
|
456 |
changes = True
|
457 |
break
|
458 |
|
459 |
-
#
|
460 |
if not np.array_equal(object_labels, new_lab):
|
461 |
changes = True
|
462 |
else:
|
@@ -477,7 +619,6 @@ def modify_results(percentage_text_dist_thresh=0.5):
|
|
477 |
new_scores = np.concatenate((object_scores, arrow_score))
|
478 |
new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
|
479 |
|
480 |
-
|
481 |
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)
|
482 |
|
483 |
st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
|
@@ -489,21 +630,35 @@ def modify_results(percentage_text_dist_thresh=0.5):
|
|
489 |
|
490 |
return True
|
491 |
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
def display_bpmn_modeler(is_mobile, screen_width):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
with st.spinner('Waiting for BPMN modeler...'):
|
497 |
st.session_state.bpmn_xml = create_XML(
|
498 |
st.session_state.prediction.copy(), st.session_state.text_mapping,
|
499 |
st.session_state.size_scale, st.session_state.scale
|
500 |
)
|
501 |
-
st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
|
502 |
|
|
|
|
|
503 |
display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
|
504 |
|
505 |
-
|
506 |
def find_best_scale(pred, size_elements):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
boxes = pred['boxes']
|
508 |
labels = pred['labels']
|
509 |
|
@@ -535,6 +690,12 @@ def find_best_scale(pred, size_elements):
|
|
535 |
return best_scale
|
536 |
|
537 |
def modeler_options(is_mobile):
|
|
|
|
|
|
|
|
|
|
|
|
|
538 |
if not is_mobile:
|
539 |
with st.expander("Options for BPMN modeler"):
|
540 |
col1, col2 = st.columns(2)
|
@@ -545,4 +706,4 @@ def modeler_options(is_mobile):
|
|
545 |
st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
|
546 |
else:
|
547 |
st.session_state.scale = 1.0
|
548 |
-
st.session_state.size_scale = 1.0
|
|
|
15 |
from modules.eval import full_prediction
|
16 |
from modules.train import get_faster_rcnn_model, get_arrow_model
|
17 |
from streamlit_image_comparison import image_comparison
|
|
|
18 |
from streamlit_image_annotation import detection
|
19 |
from modules.toXML import create_XML
|
20 |
from modules.eval import develop_prediction, generate_data
|
21 |
from modules.utils import class_dict, object_dict
|
|
|
22 |
from modules.htlm_webpage import display_bpmn_xml
|
23 |
from streamlit_cropper import st_cropper
|
24 |
from streamlit_image_select import image_select
|
25 |
from streamlit_js_eval import streamlit_js_eval
|
|
|
26 |
from modules.toWizard import create_wizard_file
|
27 |
from huggingface_hub import hf_hub_download
|
28 |
import time
|
|
|
29 |
from modules.toXML import get_size_elements
|
30 |
|
31 |
+
# Function to get memory usage
|
32 |
def get_memory_usage():
|
33 |
+
"""
|
34 |
+
Returns the current memory usage of the process in MB.
|
35 |
+
"""
|
36 |
process = psutil.Process()
|
37 |
mem_info = process.memory_info()
|
38 |
return mem_info.rss / (1024 ** 2) # Return memory usage in MB
|
39 |
|
40 |
+
# Function to clear memory
|
41 |
def clear_memory():
|
42 |
+
"""
|
43 |
+
Clears the Streamlit session state and triggers garbage collection.
|
44 |
+
"""
|
45 |
st.session_state.clear()
|
46 |
gc.collect()
|
47 |
|
|
|
48 |
# Function to read XML content from a file
|
49 |
def read_xml_file(filepath):
|
50 |
+
"""
|
51 |
+
Reads and returns the content of an XML file.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
- filepath (str): The path to the XML file.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
- str: The content of the XML file.
|
58 |
+
"""
|
59 |
with open(filepath, 'r', encoding='utf-8') as file:
|
60 |
return file.read()
|
61 |
|
|
|
62 |
# Suppress the symlink warning
|
63 |
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
|
64 |
|
65 |
# Function to load the models only once and use session state to keep track of it
|
66 |
def load_models():
|
67 |
+
"""
|
68 |
+
Loads the object and arrow detection models, either from the local file or
|
69 |
+
downloads from the Hugging Face Hub if not available locally. The models
|
70 |
+
are stored in the Streamlit session state.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
- model_object (torch.nn.Module): The loaded object detection model.
|
74 |
+
- model_arrow (torch.nn.Module): The loaded arrow detection model.
|
75 |
+
"""
|
76 |
with st.spinner('Loading model...'):
|
77 |
model_object = get_faster_rcnn_model(len(object_dict))
|
78 |
model_arrow = get_arrow_model(len(arrow_dict), 2)
|
|
|
89 |
|
90 |
# Load model arrow
|
91 |
if not Path(output_arrow).exists():
|
|
|
92 |
model_arrow.load_state_dict(torch.load(model_arrow_path, map_location=device))
|
93 |
st.session_state.model_arrow = model_arrow
|
94 |
print('Model arrow downloaded from Hugging Face Hub')
|
|
|
99 |
print()
|
100 |
st.session_state.model_arrow = model_arrow
|
101 |
print('Model arrow loaded from local file')
|
|
|
102 |
|
103 |
# Load model object
|
104 |
if not Path(output_object).exists():
|
|
|
105 |
model_object.load_state_dict(torch.load(model_object_path, map_location=device))
|
106 |
st.session_state.model_object = model_object
|
107 |
print('Model object downloaded from Hugging Face Hub')
|
|
|
108 |
torch.save(model_object.state_dict(), output_object)
|
109 |
elif 'model_object' not in st.session_state and Path(output_object).exists():
|
110 |
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
111 |
print()
|
112 |
st.session_state.model_object = model_object
|
113 |
+
print('Model object loaded from local file')
|
|
|
114 |
|
115 |
# Move models to device
|
116 |
model_arrow.to(device)
|
|
|
123 |
|
124 |
# Function to prepare the image for processing
|
125 |
def prepare_image(image, pad=True, new_size=(1333, 1333)):
|
126 |
+
"""
|
127 |
+
Resizes and optionally pads the input image to a new size.
|
128 |
+
|
129 |
+
Parameters:
|
130 |
+
- image (PIL.Image): The image to be processed.
|
131 |
+
- pad (bool): Whether to pad the image to the new size.
|
132 |
+
- new_size (tuple): The target size for the image.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
- PIL.Image: The processed image.
|
136 |
+
"""
|
137 |
original_size = image.size
|
138 |
# Calculate scale to fit the new size while maintaining aspect ratio
|
139 |
scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
|
|
|
152 |
|
153 |
# Function to display various options for image annotation
|
154 |
def display_options(image, score_threshold, is_mobile, screen_width):
|
155 |
+
"""
|
156 |
+
Displays various options for image annotation and draws the annotated image.
|
157 |
+
|
158 |
+
Parameters:
|
159 |
+
- image (PIL.Image): The image to be annotated.
|
160 |
+
- score_threshold (float): The score threshold for displaying annotations.
|
161 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
162 |
+
- screen_width (int): The width of the screen.
|
163 |
+
"""
|
164 |
col1, col2, col3, col4, col5 = st.columns(5)
|
165 |
with col1:
|
166 |
write_class = st.toggle("Write Class", value=True)
|
|
|
190 |
if is_mobile is True:
|
191 |
width = screen_width
|
192 |
else:
|
193 |
+
width = screen_width // 2
|
194 |
|
195 |
# Display the original and annotated images side by side
|
196 |
image_comparison(
|
|
|
204 |
|
205 |
# Function to perform inference on the uploaded image using the loaded models
|
206 |
def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
|
207 |
+
"""
|
208 |
+
Performs inference on the uploaded image using the loaded models and updates
|
209 |
+
the session state with predictions and text mappings.
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
- model_object (torch.nn.Module): The object detection model.
|
213 |
+
- model_arrow (torch.nn.Module): The arrow detection model.
|
214 |
+
- image (PIL.Image): The uploaded image.
|
215 |
+
- score_threshold (float): The score threshold for displaying annotations.
|
216 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
217 |
+
- screen_width (int): The width of the screen.
|
218 |
+
- iou_threshold (float): The IoU threshold for filtering boxes.
|
219 |
+
- distance_treshold (int): The distance threshold for matching keypoints.
|
220 |
+
- percentage_text_dist_thresh (float): The percentage distance threshold for text mapping.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
- tuple: The processed image, prediction, and text mapping.
|
224 |
+
"""
|
225 |
uploaded_image = prepare_image(image, pad=False)
|
|
|
226 |
img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))
|
227 |
|
228 |
# Display original image
|
|
|
231 |
if is_mobile is False:
|
232 |
width = screen_width
|
233 |
if is_mobile is False:
|
234 |
+
width = screen_width // 2
|
235 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
236 |
|
237 |
# Perform OCR on the uploaded image
|
|
|
243 |
# Prediction
|
244 |
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
245 |
|
246 |
+
# Mapping text to prediction
|
247 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
248 |
+
|
249 |
# Remove the original image display
|
250 |
image_placeholder.empty()
|
251 |
|
|
|
254 |
|
255 |
return image, st.session_state.prediction, st.session_state.text_mapping
|
256 |
|
257 |
+
# Function to get the image from the uploaded file
|
258 |
@st.cache_data
|
259 |
def get_image(uploaded_file):
|
260 |
+
"""
|
261 |
+
Opens and converts the uploaded image file to RGB format.
|
262 |
+
|
263 |
+
Parameters:
|
264 |
+
- uploaded_file: The uploaded image file.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
- PIL.Image: The opened and converted image.
|
268 |
+
"""
|
269 |
return Image.open(uploaded_file).convert('RGB')
|
270 |
|
271 |
+
# Function to configure the Streamlit page
|
272 |
def configure_page():
|
273 |
+
"""
|
274 |
+
Configures the Streamlit page layout and returns the screen width
|
275 |
+
and a flag indicating if the device is mobile.
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
279 |
+
- screen_width (int): The width of the screen.
|
280 |
+
"""
|
281 |
st.set_page_config(layout="wide")
|
282 |
screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
|
283 |
is_mobile = screen_width is not None and screen_width < 800
|
284 |
return is_mobile, screen_width
|
285 |
|
286 |
+
# Function to display the banner based on device type and theme
|
287 |
def display_banner(is_mobile):
|
|
|
|
|
|
|
288 |
"""
|
289 |
+
Displays the appropriate banner image based on device type and dark mode preference.
|
290 |
|
291 |
+
Parameters:
|
292 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
293 |
+
"""
|
294 |
+
dark_mode_js = "(window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches)"
|
295 |
is_dark_mode = streamlit_js_eval(js_expressions=dark_mode_js, key='dark_mode')
|
296 |
|
297 |
if is_mobile:
|
|
|
305 |
else:
|
306 |
st.image("./images/banner_desktop.png", use_column_width=True)
|
307 |
|
308 |
+
# Function to display the title based on device type
|
309 |
def display_title(is_mobile):
|
310 |
+
"""
|
311 |
+
Displays the title of the app based on device type.
|
312 |
+
|
313 |
+
Parameters:
|
314 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
315 |
+
"""
|
316 |
title = "Welcome on the BPMN AI model recognition app"
|
317 |
if is_mobile:
|
318 |
title = "Welcome on the mobile version of BPMN AI model recognition app"
|
319 |
st.title(title)
|
320 |
|
321 |
+
# Function to display the sidebar with instructions and information
|
322 |
def display_sidebar():
|
323 |
+
"""
|
324 |
+
Displays the sidebar with instructions and information about the app.
|
325 |
+
"""
|
326 |
st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
|
327 |
st.sidebar.subheader("Instructions:")
|
328 |
+
st.sidebar.text("1. Upload your image")
|
329 |
st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
|
330 |
st.sidebar.text("3. Set the score threshold for\n prediction (default is 0.5)")
|
331 |
st.sidebar.text("4. Click on 'Launch Prediction'")
|
|
|
333 |
st.sidebar.text("6. You can modify the result \n by clicking on:\n 'Method&Style modification'")
|
334 |
st.sidebar.text("7. You can change the scale for \n the XML file and the size of \n elements (default is 1.0)")
|
335 |
st.sidebar.text("8. You can modify with modeler \n and download the result in \n right format")
|
|
|
336 |
st.sidebar.subheader("If there is an error, try to:")
|
337 |
st.sidebar.text("1. Change the score threshold")
|
338 |
st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the\n center of the image")
|
339 |
st.sidebar.text("3. Re-Launch the prediction")
|
|
|
340 |
st.sidebar.subheader("You can close this sidebar")
|
|
|
341 |
for i in range(5):
|
342 |
st.sidebar.subheader("")
|
|
|
343 |
st.sidebar.subheader("Made with ❤️ by Benjamin.K")
|
344 |
|
345 |
+
# Function to initialize session state variables
|
346 |
def initialize_session_state():
|
347 |
+
"""
|
348 |
+
Initializes the session state variables for the app.
|
349 |
+
"""
|
350 |
if 'pool_bboxes' not in st.session_state:
|
351 |
st.session_state.pool_bboxes = []
|
352 |
if 'model_loaded' not in st.session_state:
|
|
|
356 |
load_models()
|
357 |
st.rerun()
|
358 |
|
359 |
+
# Function to load example images for testing
|
360 |
def load_example_image():
|
361 |
+
"""
|
362 |
+
Loads example images for testing the app and returns the selected image.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
- str: The path to the selected example image.
|
366 |
+
"""
|
367 |
with st.expander("Use example images"):
|
368 |
img_selected = image_select(
|
369 |
"If you have no image and just want to test the demo, click on one of these images",
|
|
|
375 |
)
|
376 |
return img_selected
|
377 |
|
378 |
+
# Function to load user-uploaded images or selected example images
|
379 |
def load_user_image(img_selected, is_mobile):
|
380 |
+
"""
|
381 |
+
Loads the user-uploaded image or the selected example image.
|
382 |
+
|
383 |
+
Parameters:
|
384 |
+
- img_selected (str): The path to the selected example image.
|
385 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
386 |
+
|
387 |
+
Returns:
|
388 |
+
- str: The path to the uploaded image file.
|
389 |
+
"""
|
390 |
if img_selected == './images/none.jpg':
|
391 |
img_selected = None
|
|
|
392 |
if img_selected is not None:
|
393 |
uploaded_file = img_selected
|
394 |
else:
|
|
|
398 |
col1, col2 = st.columns(2)
|
399 |
with col1:
|
400 |
uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
|
|
|
401 |
return uploaded_file
|
402 |
|
403 |
+
# Function to display the uploaded or example image
|
404 |
def display_image(uploaded_file, screen_width, is_mobile):
|
405 |
+
"""
|
406 |
+
Displays the uploaded or selected example image with options to rotate and adjust brightness.
|
407 |
+
|
408 |
+
Parameters:
|
409 |
+
- uploaded_file: The uploaded image file.
|
410 |
+
- screen_width (int): The width of the screen.
|
411 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
412 |
+
|
413 |
+
Returns:
|
414 |
+
- PIL.Image: The cropped and adjusted image.
|
415 |
+
"""
|
416 |
if 'rotation_angle' not in st.session_state:
|
417 |
st.session_state.rotation_angle = 0 # Initialize the rotation angle in session state
|
|
|
418 |
if 'brightness' not in st.session_state:
|
419 |
st.session_state.brightness = 1.0 # Initialize brightness in session state
|
420 |
|
|
|
457 |
if not is_mobile:
|
458 |
cropped_image = crop_image(adjusted_image, original_image)
|
459 |
else:
|
460 |
+
st.image(adjusted_image, caption="Image", use_column_width=False, width=int(4 / 5 * screen_width))
|
461 |
cropped_image = original_image
|
462 |
|
463 |
return cropped_image
|
464 |
|
465 |
+
# Function to crop the image
|
|
|
|
|
466 |
def crop_image(resized_image, original_image):
|
467 |
+
"""
|
468 |
+
Crops the resized image based on user input.
|
469 |
+
|
470 |
+
Parameters:
|
471 |
+
- resized_image (PIL.Image): The resized image.
|
472 |
+
- original_image (PIL.Image): The original image.
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
- PIL.Image: The cropped image.
|
476 |
+
"""
|
477 |
marge = 10
|
478 |
cropped_box = st_cropper(
|
479 |
resized_image,
|
|
|
489 |
cropped_image = original_image.crop((x0, y0, x1, y1))
|
490 |
return cropped_image
|
491 |
|
492 |
+
# Function to get the score threshold for prediction
|
493 |
def get_score_threshold(is_mobile):
|
494 |
+
"""
|
495 |
+
Displays a slider to set the score threshold for prediction.
|
496 |
+
|
497 |
+
Parameters:
|
498 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
499 |
+
"""
|
500 |
col1, col2 = st.columns(2)
|
501 |
with col1:
|
502 |
+
st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
|
503 |
|
504 |
def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
|
505 |
+
"""
|
506 |
+
Launches the prediction process on the cropped image and displays balloons upon completion.
|
507 |
+
|
508 |
+
Parameters:
|
509 |
+
- cropped_image (PIL.Image): The cropped image to be processed.
|
510 |
+
- score_threshold (float): The score threshold for predictions.
|
511 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
512 |
+
- screen_width (int): The width of the screen.
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
- PIL.Image: The image after performing inference.
|
516 |
+
"""
|
517 |
st.session_state.crop_image = cropped_image
|
518 |
with st.spinner('Processing...'):
|
519 |
+
image, _, _ = perform_inference(
|
520 |
st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
|
521 |
score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
|
522 |
)
|
523 |
st.balloons()
|
524 |
return image
|
|
|
525 |
|
526 |
def modify_results(percentage_text_dist_thresh=0.5):
|
527 |
+
"""
|
528 |
+
Allows the user to modify the results using method and style modification.
|
529 |
+
|
530 |
+
Parameters:
|
531 |
+
- percentage_text_dist_thresh (float): Threshold for mapping text to predictions based on percentage distance.
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
- bool: True if changes are detected and modifications are made, otherwise False.
|
535 |
+
"""
|
536 |
with st.expander("Method & Style modification"):
|
537 |
label_list = list(object_dict.values())
|
538 |
if st.session_state.prediction['labels'][-1] == 6:
|
|
|
588 |
|
589 |
object_labels = np.array(object_labels)
|
590 |
|
|
|
591 |
if len(object_bboxes) == len(bboxes):
|
592 |
# Calculate absolute differences
|
593 |
abs_diff = np.abs(object_bboxes - bboxes)
|
|
|
598 |
changes = True
|
599 |
break
|
600 |
|
601 |
+
# Check if labels are the same
|
602 |
if not np.array_equal(object_labels, new_lab):
|
603 |
changes = True
|
604 |
else:
|
|
|
619 |
new_scores = np.concatenate((object_scores, arrow_score))
|
620 |
new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
|
621 |
|
|
|
622 |
boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)
|
623 |
|
624 |
st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
|
|
|
630 |
|
631 |
return True
|
632 |
|
|
|
|
|
|
|
633 |
def display_bpmn_modeler(is_mobile, screen_width):
|
634 |
+
"""
|
635 |
+
Displays the BPMN modeler with the current prediction and text mapping.
|
636 |
+
|
637 |
+
Parameters:
|
638 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
639 |
+
- screen_width (int): The width of the screen.
|
640 |
+
"""
|
641 |
with st.spinner('Waiting for BPMN modeler...'):
|
642 |
st.session_state.bpmn_xml = create_XML(
|
643 |
st.session_state.prediction.copy(), st.session_state.text_mapping,
|
644 |
st.session_state.size_scale, st.session_state.scale
|
645 |
)
|
|
|
646 |
|
647 |
+
st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
|
648 |
+
|
649 |
display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
|
650 |
|
|
|
651 |
def find_best_scale(pred, size_elements):
|
652 |
+
"""
|
653 |
+
Finds the best scale for the elements in the prediction.
|
654 |
+
|
655 |
+
Parameters:
|
656 |
+
- pred (dict): The prediction data.
|
657 |
+
- size_elements (dict): The size elements dictionary.
|
658 |
+
|
659 |
+
Returns:
|
660 |
+
- float: The best scale for the elements.
|
661 |
+
"""
|
662 |
boxes = pred['boxes']
|
663 |
labels = pred['labels']
|
664 |
|
|
|
690 |
return best_scale
|
691 |
|
692 |
def modeler_options(is_mobile):
|
693 |
+
"""
|
694 |
+
Displays options for the BPMN modeler.
|
695 |
+
|
696 |
+
Parameters:
|
697 |
+
- is_mobile (bool): Flag indicating if the device is mobile.
|
698 |
+
"""
|
699 |
if not is_mobile:
|
700 |
with st.expander("Options for BPMN modeler"):
|
701 |
col1, col2 = st.columns(2)
|
|
|
706 |
st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
|
707 |
else:
|
708 |
st.session_state.scale = 1.0
|
709 |
+
st.session_state.size_scale = 1.0
|
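For readers skimming the diff, a minimal standalone sketch of the aspect-ratio resize idea that the new prepare_image docstring above describes. The helper name resize_with_padding, the white canvas, and the top-left paste position are illustrative assumptions, not code from the repository:

from PIL import Image

def resize_with_padding(image: Image.Image, target=(1333, 1333)) -> Image.Image:
    # Scale so the image fits inside the target while keeping its aspect ratio
    scale = min(target[0] / image.size[0], target[1] / image.size[1])
    new_size = (int(image.size[0] * scale), int(image.size[1] * scale))
    resized = image.resize(new_size)
    # Pad by pasting the resized image onto a target-sized white canvas
    canvas = Image.new('RGB', target, (255, 255, 255))
    canvas.paste(resized, (0, 0))
    return canvas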
modules/toWizard.py
CHANGED
@@ -4,13 +4,31 @@ from xml.dom import minidom
from modules.utils import error
from modules.OCR import analyze_sentiment

def rescale(scale, boxes):
    """
    Rescale the coordinates of the bounding boxes by a given scale factor.

    Args:
        scale (float): The scale factor to apply.
        boxes (list): List of bounding boxes to be rescaled.

    Returns:
        list: Rescaled bounding boxes.
    """
    for i in range(len(boxes)):
        boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
    return boxes

def create_BPMN_id(data):
    """
    Create unique BPMN IDs for each element in the data based on their types.

    Args:
        data (dict): Dictionary containing labels and links of elements.

    Returns:
        dict: Updated data with BPMN IDs assigned.
    """
    enum_end, enum_start, enum_task, enum_sequence, enum_dataflow, enum_messflow, enum_messageEvent, enum_exclusiveGateway, enum_parallelGateway, enum_pool = 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
    BPMN_name = [class_dict[data['labels'][i]] for i in range(len(data['labels']))]
    for idx, Bpmn_id in enumerate(BPMN_name):
@@ -49,15 +67,35 @@ def create_BPMN_id(data):
    return data

def check_end(link):
    """
    Check if a link represents an end event.

    Args:
        link (tuple): A link containing indices of connected elements.

    Returns:
        bool: True if the link represents an end event, False otherwise.
    """
    if link[1] is None:
        return True
    return False

def connect(data, text_mapping, i):
    """
    Connect elements based on their links and generate the corresponding text mapping.

    Args:
        data (dict): Data containing links and BPMN IDs.
        text_mapping (dict): Mapping of BPMN IDs to their text descriptions.
        i (int): Index of the current element.

    Returns:
        tuple: Current text, next texts, and next ID.
    """
    next_text = []
    target_idx = data['links'][i][1]
    # Check if the target index is valid
    if target_idx == None or target_idx >= len(data['links']):
        error('There may be an error with the Vizi file, care when you download it.')
        return None, None, None
@@ -80,11 +118,30 @@
    return current_text, next_text, next_id

def check_start(val):
    """
    Check if a link represents a start event.

    Args:
        val (tuple): A link containing indices of connected elements.

    Returns:
        bool: True if the link represents a start event, False otherwise.
    """
    if val[0] is None:
        return True
    return False

def find_merge(bpmn_id, links):
    """
    Identify merge points in the BPMN diagram.

    Args:
        bpmn_id (list): List of BPMN IDs.
        links (list): List of links between elements.

    Returns:
        list: List indicating merge points.
    """
    merge = []
    for idx, link in enumerate(links):
        next_element = link[1]
@@ -104,7 +161,7 @@ def find_merge(bpmn_id, links):
        if element is None:
            merge_elements[idx] = False
            continue
        # Count how many times the element is in the list
        count = merge.count(element)
        if count > 1:
            merge_elements[idx] = True
@@ -114,6 +171,17 @@
    return merge_elements

def find_positive_end(bpmn_ids, links, text_mapping):
    """
    Find the positive end event based on sentiment analysis.

    Args:
        bpmn_ids (list): List of BPMN IDs.
        links (list): List of links between elements.
        text_mapping (dict): Mapping of BPMN IDs to their text descriptions.

    Returns:
        str: BPMN ID of the positive end event.
    """
    emotion_data = []
    for idx, bpmn_id in enumerate(bpmn_ids):
        if idx >= len(links):
@@ -130,6 +198,15 @@
    return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None

def find_best_direction(texts_list):
    """
    Find the best direction based on sentiment analysis.

    Args:
        texts_list (list): List of texts to analyze.

    Returns:
        str: Text with the best (positive) sentiment.
    """
    emotion_data = []
    for text in texts_list:
        highest_emotion, highest_score = analyze_sentiment(text)
@@ -141,18 +218,24 @@
    return sorted_emotions[0][0] if len(sorted_emotions) > 0 else None

def create_wizard_file(data, text_mapping):
    """
    Create a wizard file for BPMN modeling based on the provided data and text mappings.

    Args:
        data (dict): Data containing BPMN elements and their properties.
        text_mapping (dict): Mapping of BPMN IDs to their text descriptions.

    Returns:
        str: Pretty-printed XML string of the wizard file.
    """
    not_change = ['pool','sequenceFlow','messageFlow','dataAssociation']

    # Add a name into the text_mapping when there is no name
    for idx, key in enumerate(text_mapping.keys()):
        if text_mapping[key] == '' and key.split('_')[0] not in not_change:
            text_mapping[key] = f'unnamed_{key}'

    root = ET.Element('methodAndStyleWizard')

    modelName = ET.SubElement(root, 'modelName')
@@ -179,7 +262,7 @@
        eventType = 'None'
        if idx >= len(data['links']):
            continue
        if check_start(data['links'][idx]) and (element_type == 'event' or element_type == 'message'):
            if text_mapping[Bpmn_id] == '':
                text_mapping[Bpmn_id] = 'start'
            startEvent = ET.SubElement(root, 'startEvent', attrib={'name': text_mapping[Bpmn_id], 'eventType': eventType, 'isRegular': 'True'})
@@ -191,8 +274,7 @@
    positive_end = find_positive_end(data['BPMN_id'], data['links'], text_mapping)
    if positive_end is not None:
        print("Best end is: ", text_mapping[positive_end])

    # Add end states event to the collaboration element
    for idx, Bpmn_id in enumerate(data['BPMN_id']):
@@ -208,7 +290,6 @@
        else:
            ET.SubElement(endEvents, 'endState', attrib={'name': text_mapping[Bpmn_id], 'eventType': 'None', 'isRegular': 'False'})

    # Add activities to the collaboration element
    activities = ET.SubElement(root, 'activities')
    for idx, activity_name in enumerate(data['BPMN_id']):
@@ -269,7 +350,7 @@
    ET.SubElement(root, 'participants')

    # Pretty print the XML
    pwm_str = ET.tostring(root, encoding='utf-8', method='xml')
    pretty_pwm_str = minidom.parseString(pwm_str).toprettyxml(indent="  ")

    return pretty_pwm_str
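As a standalone illustration of the ElementTree-to-minidom pretty-printing pattern that create_wizard_file now ends with (this snippet is not the app's code, just the same idiom in isolation):

import xml.etree.ElementTree as ET
from xml.dom import minidom

# Build a tiny element tree and pretty-print it, mirroring the end of create_wizard_file
root = ET.Element('methodAndStyleWizard')
ET.SubElement(root, 'modelName').text = 'example'
raw = ET.tostring(root, encoding='utf-8', method='xml')
print(minidom.parseString(raw).toprettyxml(indent="  "))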
modules/toXML.py
CHANGED
@@ -7,7 +7,16 @@ from xml.dom import minidom
|
|
7 |
import numpy as np
|
8 |
|
9 |
def find_position(pool_index, BPMN_id):
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
if pool_index in BPMN_id:
|
12 |
position = BPMN_id.index(pool_index)
|
13 |
else:
|
@@ -18,6 +27,16 @@ def find_position(pool_index, BPMN_id):
|
|
18 |
|
19 |
# Calculate the center of each bounding box and group them by pool
|
20 |
def calculate_centers_and_group_by_pool(pred, class_dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
pool_groups = {}
|
22 |
for pool_index, element_indices in pred['pool_dict'].items():
|
23 |
pool_groups[pool_index] = []
|
@@ -26,12 +45,23 @@ def calculate_centers_and_group_by_pool(pred, class_dict):
|
|
26 |
continue
|
27 |
if class_dict[pred['labels'][i]] not in ['dataObject', 'dataStore']:
|
28 |
x1, y1, x2, y2 = pred['boxes'][i]
|
29 |
-
center = [(x1 + x2) / 2, (y1 + y2) / 2]
|
30 |
pool_groups[pool_index].append((center, i))
|
31 |
return pool_groups
|
32 |
|
33 |
# Group centers within a specified range
|
34 |
def group_centers(centers, axis, range_=50):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
groups = []
|
36 |
while centers:
|
37 |
center, idx = centers.pop(0)
|
@@ -45,18 +75,38 @@ def group_centers(centers, axis, range_=50):
|
|
45 |
|
46 |
# Align the elements within each pool
|
47 |
def align_elements_within_pool(modified_pred, pool_groups, class_dict, size):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
for pool_index, centers in pool_groups.items():
|
|
|
49 |
y_groups = group_centers(centers.copy(), axis=1)
|
50 |
align_y_coordinates(modified_pred, y_groups, class_dict, size)
|
51 |
|
|
|
52 |
centers = recalculate_centers(modified_pred, y_groups)
|
53 |
x_groups = group_centers(centers.copy(), axis=0)
|
54 |
align_x_coordinates(modified_pred, x_groups, class_dict, size)
|
55 |
|
56 |
# Align the y-coordinates of the centers of grouped bounding boxes
|
57 |
def align_y_coordinates(modified_pred, y_groups, class_dict, size):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
for group in y_groups:
|
59 |
-
avg_y = sum([c[0][1] for c in group]) / len(group)
|
60 |
for (center, idx) in group:
|
61 |
label = class_dict[modified_pred['labels'][idx]]
|
62 |
if label in size:
|
@@ -70,18 +120,37 @@ def align_y_coordinates(modified_pred, y_groups, class_dict, size):
|
|
70 |
|
71 |
# Recalculate centers after alignment
|
72 |
def recalculate_centers(modified_pred, groups):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
centers = []
|
74 |
for group in groups:
|
75 |
for center, idx in group:
|
76 |
x1, y1, x2, y2 = modified_pred['boxes'][idx]
|
77 |
-
center = [(x1 + x2) / 2, (y1 + y2) / 2]
|
78 |
centers.append((center, idx))
|
79 |
return centers
|
80 |
|
81 |
# Align the x-coordinates of the centers of grouped bounding boxes
|
82 |
def align_x_coordinates(modified_pred, x_groups, class_dict, size):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
for group in x_groups:
|
84 |
-
avg_x = sum([c[0][0] for c in group]) / len(group)
|
85 |
for (center, idx) in group:
|
86 |
label = class_dict[modified_pred['labels'][idx]]
|
87 |
if label in size:
|
@@ -95,6 +164,13 @@ def align_x_coordinates(modified_pred, x_groups, class_dict, size):
|
|
95 |
|
96 |
# Expand the pool bounding boxes to fit the aligned elements
|
97 |
def expand_pool_bounding_boxes(modified_pred, size_elements):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
|
99 |
if len(keep_elements) != 0:
|
100 |
marge = size_elements['task'][1] // 2
|
@@ -114,10 +190,18 @@ def expand_pool_bounding_boxes(modified_pred, size_elements):
|
|
114 |
error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
|
115 |
continue
|
116 |
|
|
|
117 |
modified_pred['boxes'][position] = [min_x - marge, min_y - marge//2, min_x + pool_width + marge, min_y + pool_height + marge//2]
|
118 |
|
119 |
# Adjust left and right boundaries of all pools
|
120 |
def adjust_pool_boundaries(modified_pred, pred):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
min_left, max_right = 0, 0
|
122 |
for pool_index, element_indices in pred['pool_dict'].items():
|
123 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
@@ -140,10 +224,22 @@ def adjust_pool_boundaries(modified_pred, pred):
|
|
140 |
x1 = min_left
|
141 |
if x2 < max_right:
|
142 |
x2 = max_right
|
|
|
143 |
modified_pred['boxes'][position] = [x1, y1, x2, y2]
|
144 |
|
145 |
# Main function to align boxes
|
146 |
def align_boxes(pred, size, class_dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
modified_pred = copy.deepcopy(pred)
|
148 |
pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
|
149 |
align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
|
@@ -154,9 +250,20 @@ def align_boxes(pred, size, class_dict):
|
|
154 |
|
155 |
return modified_pred['boxes']
|
156 |
|
157 |
-
|
158 |
# Function to create a BPMN XML file from prediction results
|
159 |
def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
namespaces = {
|
161 |
'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
|
162 |
'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
|
@@ -165,7 +272,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
165 |
'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
|
166 |
}
|
167 |
|
168 |
-
|
169 |
definitions = ET.Element('bpmn:definitions', {
|
170 |
'xmlns:xsi': namespaces['xsi'],
|
171 |
'xmlns:bpmn': namespaces['bpmn'],
|
@@ -176,14 +282,13 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
176 |
'id': "simpleExample"
|
177 |
})
|
178 |
|
179 |
-
|
180 |
size_elements = get_size_elements(size_scale)
|
181 |
|
182 |
-
#
|
183 |
if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
|
184 |
full_pred, text_mapping = create_big_pool(full_pred, text_mapping, size_elements)
|
185 |
|
186 |
-
#
|
187 |
old_boxes = copy.deepcopy(full_pred)
|
188 |
|
189 |
# Create BPMN collaboration element
|
@@ -191,16 +296,16 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
191 |
|
192 |
# Create BPMN process elements
|
193 |
process = []
|
194 |
-
for idx in range
|
195 |
-
process_id = f'process_{idx+1}'
|
196 |
process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
|
197 |
|
198 |
bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
|
199 |
bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
|
200 |
|
|
|
201 |
full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
|
202 |
full_pred['boxes'] = align_boxes(full_pred, size_elements, class_dict)
|
203 |
-
|
204 |
|
205 |
# Add diagram elements for each pool
|
206 |
for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
|
@@ -208,8 +313,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
208 |
pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[pool_index])
|
209 |
|
210 |
position = find_position(pool_index, full_pred['BPMN_id'])
|
211 |
-
# Calculate the bounding box for the pool
|
212 |
-
#if len(keep_elements) == 0:
|
213 |
if position >= len(full_pred['boxes']):
|
214 |
print("Problem with the index")
|
215 |
continue
|
@@ -219,7 +322,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
219 |
|
220 |
add_diagram_elements(bpmnplane, pool_id, min_x, min_y, pool_width, pool_height)
|
221 |
|
222 |
-
|
223 |
# Create BPMN elements for each pool
|
224 |
for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
|
225 |
create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
|
@@ -244,6 +346,7 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
244 |
reparsed = minidom.parseString(rough_string)
|
245 |
pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
|
246 |
|
|
|
247 |
full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
|
248 |
full_pred['boxes'] = old_boxes
|
249 |
|
@@ -251,11 +354,22 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
|
|
251 |
|
252 |
# Function that creates a single pool with all elements
|
253 |
def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
new_pool_index = 'pool_1'
|
256 |
size_elements = get_size_elements(st.session_state.size_scale)
|
257 |
elements_pool = list(range(len(full_pred['boxes'])))
|
258 |
-
min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
|
259 |
box = [min_x - marge, min_y - marge//2, max_x + marge, max_y + marge//2]
|
260 |
full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
|
261 |
full_pred['pool_dict'][new_pool_index] = elements_pool
|
@@ -266,33 +380,61 @@ def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
|
|
266 |
|
267 |
# Function that gives the size of the elements
|
268 |
def get_size_elements(size_scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
size_elements = {
|
270 |
-
'event': (size_scale*43.2, size_scale*43.2),
|
271 |
-
'task': (size_scale*120, size_scale*96),
|
272 |
-
'message': (size_scale*43.2, size_scale*43.2),
|
273 |
-
'messageEvent': (size_scale*43.2, size_scale*43.2),
|
274 |
-
'exclusiveGateway': (size_scale*60, size_scale*60),
|
275 |
-
'parallelGateway': (size_scale*60, size_scale*60),
|
276 |
-
'dataObject': (size_scale*48, size_scale*72),
|
277 |
-
'dataStore': (size_scale*72, size_scale*72),
|
278 |
-
'subProcess': (size_scale*144, size_scale*108),
|
279 |
-
'eventBasedGateway': (size_scale*60, size_scale*60),
|
280 |
-
'timerEvent': (size_scale*48, size_scale*48),
|
281 |
}
|
282 |
return size_elements
|
283 |
|
284 |
def rescale(scale, boxes):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
for i in range(len(boxes)):
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
return boxes
|
291 |
|
292 |
-
#Function to create the unique BPMN_id
|
293 |
-
def create_BPMN_id(labels,pool_dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
-
|
|
|
|
|
|
|
296 |
|
297 |
data_counter = 1
|
298 |
|
@@ -336,7 +478,7 @@ def create_BPMN_id(labels,pool_dict):
|
|
336 |
else:
|
337 |
BPMN_id[idx] = f'{key}_{enums[key]}'
|
338 |
enums[key] += 1
|
339 |
-
|
340 |
# Update the pool_dict keys with their corresponding BPMN_id values
|
341 |
updated_pool_dict = {}
|
342 |
for key, value in pool_dict.items():
|
@@ -346,10 +488,18 @@ def create_BPMN_id(labels,pool_dict):
|
|
346 |
|
347 |
return BPMN_id, updated_pool_dict
|
348 |
|
349 |
-
|
350 |
-
|
351 |
def add_diagram_elements(parent, element_id, x, y, width, height):
|
352 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
|
354 |
'bpmnElement': element_id,
|
355 |
'id': element_id + '_di'
|
@@ -362,7 +512,14 @@ def add_diagram_elements(parent, element_id, x, y, width, height):
|
|
362 |
})
|
363 |
|
364 |
def add_diagram_edge(parent, element_id, waypoints):
|
365 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
|
367 |
'bpmnElement': element_id,
|
368 |
'id': element_id + '_di'
|
@@ -375,8 +532,17 @@ def add_diagram_edge(parent, element_id, waypoints):
|
|
375 |
'y': str(y)
|
376 |
})
|
377 |
|
378 |
-
|
379 |
def check_status(link, keep_elements):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
if link[0] in keep_elements and link[1] in keep_elements:
|
381 |
return 'middle'
|
382 |
elif link[0] is None and link[1] in keep_elements:
|
@@ -385,40 +551,87 @@ def check_status(link, keep_elements):
|
|
385 |
return 'end'
|
386 |
else:
|
387 |
return 'middle'
|
388 |
-
|
389 |
def check_data_association(i, links, labels, keep_elements):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
status, links_idx = [], []
|
391 |
-
for j, (k,l) in enumerate(links):
|
392 |
if labels[j] == list(class_dict.values()).index('dataAssociation'):
|
393 |
-
if k==i:
|
394 |
status.append('output')
|
395 |
links_idx.append(j)
|
396 |
-
elif l==i:
|
397 |
status.append('input')
|
398 |
links_idx.append(j)
|
399 |
|
400 |
return status, links_idx
|
401 |
|
402 |
-
def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
|
404 |
if waypoints is not None:
|
405 |
add_diagram_edge(bpmn, element_id, waypoints)
|
406 |
-
|
407 |
def check_eventBasedGateway(i, links, labels):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
status, links_idx = [], []
|
409 |
-
for j, (k,l) in enumerate(links):
|
410 |
if labels[j] == list(class_dict.values()).index('sequenceFlow'):
|
411 |
-
            if k==i:
                status.append('output')
                links_idx.append(j)
            elif l==i:
                status.append('input')
                links_idx.append(j)

    return status, links_idx
-
 # Function to dynamically create and layout BPMN elements
 def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
     elements = data['BPMN_id']
     positions = data['boxes']
     links = data['links']
@@ -536,7 +749,6 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
                     sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway_{link_idx}_{gateway_name.split("_")[1]}')
                     create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, element_id, gateway_name)
 
-
         add_diagram_elements(bpmnplane, element_id, x, y, size['eventBasedGateway'][0], size['eventBasedGateway'][1])
 
     # Data Object
@@ -558,6 +770,19 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
         add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
 
 def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=None):
     min_x, min_y = float('inf'), float('inf')
     max_x, max_y = float('-inf'), float('-inf')
 
@@ -588,9 +813,22 @@ def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=No
 
     return min_x, min_y, max_x, max_y
 
-
-
 def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
     # Get the bounding boxes of the source and target elements
     source_box = data['boxes'][source_idx]
     target_box = data['boxes'][target_idx]
@@ -625,11 +863,19 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_ele
         waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
 
     return waypoints
-
 def add_curve(waypoints, pos_source, pos_target, threshold=30):
     """
     Add a single curve to the sequence flow by introducing a control point.
     The control point is added at an offset from the midpoint of the original waypoints.
     """
     if len(waypoints) < 2:
         return waypoints
@@ -647,7 +893,7 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
     if abs(start_x - end_x) < threshold or abs(start_y - end_y) < threshold:
         return waypoints
 
-    # Calculate the control point
     if pos_source in pos_horizontal and pos_target in pos_horizontal:
         control_point = None
     elif pos_source in pos_vertical and pos_target in pos_vertical:
@@ -658,7 +904,6 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
         control_point = (start_x, end_y)
     else:
         control_point = None
-
 
     # Create the curved path
     if control_point is not None:
@@ -668,8 +913,20 @@ def add_curve(waypoints, pos_source, pos_target, threshold=30):
 
     return curved_waypoints
 
-
 def calculate_waypoints(data, size, current_idx, source_id, target_id):
     best_points = data['best_points'][current_idx]
     pos_source = best_points[0]
     pos_target = best_points[1]
@@ -684,7 +941,6 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
     if source_idx is None or target_idx is None:
         warning()
         return None
-
 
     name_source = source_id.split('_')[0]
     name_target = target_id.split('_')[0]
@@ -702,6 +958,7 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
         warning()
         return [(source_x, source_y), (target_x, target_y)]
 
     if pos_source == 'left':
         source_x = source_x
         source_y += size[name_source][1] / 2
@@ -715,6 +972,7 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
         source_x += size[name_source][0] / 2
         source_y += size[name_source][1]
 
     if pos_target == 'left':
         target_x = target_x
         target_y += size[name_target][1] / 2
@@ -738,8 +996,19 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
 
     return curved_waypoints
 
-
 def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
     source_idx, target_idx = data['links'][idx]
 
     if source_idx is None or target_idx is None:
@@ -774,6 +1043,3 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
         return
     element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
     add_diagram_edge(bpmn, element_id, waypoints)
-
-
-
 import numpy as np
 
 def find_position(pool_index, BPMN_id):
+    """
+    Find the position of the pool index in the BPMN_id list.
+
+    Args:
+        pool_index (str): The pool index to search for.
+        BPMN_id (list): List of BPMN IDs.
+
+    Returns:
+        int: The index of the pool_index in BPMN_id, or None if not found.
+    """
     if pool_index in BPMN_id:
         position = BPMN_id.index(pool_index)
     else:
 
 # Calculate the center of each bounding box and group them by pool
 def calculate_centers_and_group_by_pool(pred, class_dict):
+    """
+    Calculate the center coordinates of bounding boxes and group them by pool.
+
+    Args:
+        pred (dict): Dictionary containing prediction results, including 'pool_dict', 'boxes', and 'labels'.
+        class_dict (dict): Dictionary mapping class indices to class names.
+
+    Returns:
+        dict: Dictionary grouping centers and their indices by pool index.
+    """
     pool_groups = {}
     for pool_index, element_indices in pred['pool_dict'].items():
         pool_groups[pool_index] = []
 
                 continue
             if class_dict[pred['labels'][i]] not in ['dataObject', 'dataStore']:
                 x1, y1, x2, y2 = pred['boxes'][i]
+                center = [(x1 + x2) / 2, (y1 + y2) / 2]  # Compute the center of the bounding box
                 pool_groups[pool_index].append((center, i))
     return pool_groups
 
 # Group centers within a specified range
 def group_centers(centers, axis, range_=50):
+    """
+    Group centers based on a specified range along an axis.
+
+    Args:
+        centers (list): List of center coordinates and their indices.
+        axis (int): The axis (0 for x, 1 for y) to group centers along.
+        range_ (int): Maximum distance to consider centers as part of the same group.
+
+    Returns:
+        list: List of groups, where each group is a list of centers and indices.
+    """
     groups = []
     while centers:
         center, idx = centers.pop(0)
 
 # Align the elements within each pool
 def align_elements_within_pool(modified_pred, pool_groups, class_dict, size):
+    """
+    Align elements within each pool based on their centers.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        pool_groups (dict): Dictionary grouping centers and their indices by pool index.
+        class_dict (dict): Dictionary mapping class indices to class names.
+        size (dict): Dictionary containing element sizes.
+    """
     for pool_index, centers in pool_groups.items():
+        # Align elements based on y-coordinates
         y_groups = group_centers(centers.copy(), axis=1)
         align_y_coordinates(modified_pred, y_groups, class_dict, size)
 
+        # Recalculate centers after y-alignment and then align based on x-coordinates
         centers = recalculate_centers(modified_pred, y_groups)
         x_groups = group_centers(centers.copy(), axis=0)
         align_x_coordinates(modified_pred, x_groups, class_dict, size)
 
 # Align the y-coordinates of the centers of grouped bounding boxes
 def align_y_coordinates(modified_pred, y_groups, class_dict, size):
+    """
+    Align the y-coordinates of elements in each group.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        y_groups (list): List of groups of centers and their indices, grouped by y-coordinate.
+        class_dict (dict): Dictionary mapping class indices to class names.
+        size (dict): Dictionary containing element sizes.
+    """
     for group in y_groups:
+        avg_y = sum([c[0][1] for c in group]) / len(group)  # Compute the average y-coordinate
         for (center, idx) in group:
             label = class_dict[modified_pred['labels'][idx]]
             if label in size:
 
 # Recalculate centers after alignment
 def recalculate_centers(modified_pred, groups):
+    """
+    Recalculate the centers of bounding boxes after alignment.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        groups (list): List of groups of centers and their indices.
+
+    Returns:
+        list: List of recalculated centers and their indices.
+    """
     centers = []
     for group in groups:
         for center, idx in group:
             x1, y1, x2, y2 = modified_pred['boxes'][idx]
+            center = [(x1 + x2) / 2, (y1 + y2) / 2]  # Recompute the center after alignment
             centers.append((center, idx))
     return centers
 
 # Align the x-coordinates of the centers of grouped bounding boxes
 def align_x_coordinates(modified_pred, x_groups, class_dict, size):
+    """
+    Align the x-coordinates of elements in each group.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        x_groups (list): List of groups of centers and their indices, grouped by x-coordinate.
+        class_dict (dict): Dictionary mapping class indices to class names.
+        size (dict): Dictionary containing element sizes.
+    """
     for group in x_groups:
+        avg_x = sum([c[0][0] for c in group]) / len(group)  # Compute the average x-coordinate
         for (center, idx) in group:
             label = class_dict[modified_pred['labels'][idx]]
             if label in size:
 
 # Expand the pool bounding boxes to fit the aligned elements
 def expand_pool_bounding_boxes(modified_pred, size_elements):
+    """
+    Expand the bounding boxes of pools to fit aligned elements.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        size_elements (dict): Dictionary containing element sizes.
+    """
     for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
         if len(keep_elements) != 0:
             marge = size_elements['task'][1] // 2
 
             error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
             continue
 
+        # Update the pool bounding box with margin
         modified_pred['boxes'][position] = [min_x - marge, min_y - marge//2, min_x + pool_width + marge, min_y + pool_height + marge//2]
 
 # Adjust left and right boundaries of all pools
 def adjust_pool_boundaries(modified_pred, pred):
+    """
+    Adjust the left and right boundaries of all pools to ensure they cover all elements.
+
+    Args:
+        modified_pred (dict): Dictionary containing the modified predictions.
+        pred (dict): Dictionary containing original prediction results.
+    """
     min_left, max_right = 0, 0
     for pool_index, element_indices in pred['pool_dict'].items():
         position = find_position(pool_index, modified_pred['BPMN_id'])
 
             x1 = min_left
         if x2 < max_right:
             x2 = max_right
+        # Update the pool bounding box with adjusted boundaries
         modified_pred['boxes'][position] = [x1, y1, x2, y2]
 
 # Main function to align boxes
 def align_boxes(pred, size, class_dict):
+    """
+    Main function to align bounding boxes for the given prediction data.
+
+    Args:
+        pred (dict): Dictionary containing prediction results.
+        size (dict): Dictionary containing element sizes.
+        class_dict (dict): Dictionary mapping class indices to class names.
+
+    Returns:
+        list: List of aligned bounding boxes.
+    """
     modified_pred = copy.deepcopy(pred)
     pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
     align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
 
     return modified_pred['boxes']
 
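The grouping logic above is the heart of the auto-layout: centers that are close along one axis are collected into a group and snapped to a shared coordinate. A minimal, self-contained sketch of that idea follows; the helper name `group_and_snap` is hypothetical and is not the repository's API, it only illustrates the same grouping-then-averaging step.

def group_and_snap(centers, axis, range_=50):
    # Group 2D centers whose coordinate along `axis` lies within `range_` of a
    # seed center, then snap every member of the group to the group average.
    remaining = list(centers)
    snapped = []
    while remaining:
        seed = remaining.pop(0)
        group = [seed] + [c for c in remaining if abs(c[axis] - seed[axis]) <= range_]
        remaining = [c for c in remaining if c not in group]
        avg = sum(c[axis] for c in group) / len(group)
        for c in group:
            aligned = list(c)
            aligned[axis] = avg
            snapped.append(tuple(aligned))
    return snapped

print(group_and_snap([(100, 102), (300, 98), (500, 240)], axis=1))
# y=102 and y=98 fall in one group and are snapped to y=100.0; y=240 stays alone.
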
 # Function to create a BPMN XML file from prediction results
 def create_XML(full_pred, text_mapping, size_scale, scale):
+    """
+    Create a BPMN XML file from the prediction results.
+
+    Args:
+        full_pred (dict): Dictionary containing full prediction results.
+        text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
+        size_scale (float): Scaling factor for element sizes.
+        scale (float): Scaling factor for bounding boxes.
+
+    Returns:
+        str: Pretty-printed BPMN XML string.
+    """
     namespaces = {
         'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
         'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
 
         'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
     }
 
     definitions = ET.Element('bpmn:definitions', {
         'xmlns:xsi': namespaces['xsi'],
         'xmlns:bpmn': namespaces['bpmn'],
 
         'id': "simpleExample"
     })
 
     size_elements = get_size_elements(size_scale)
 
+    # If there is no pool or lane, create a pool with all elements
     if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
         full_pred, text_mapping = create_big_pool(full_pred, text_mapping, size_elements)
 
+    # Backup the original box positions
     old_boxes = copy.deepcopy(full_pred)
 
     # Create BPMN collaboration element
 
     # Create BPMN process elements
     process = []
+    for idx in range(len(full_pred['pool_dict'].items())):
+        process_id = f'process_{idx+1}'
         process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
 
     bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
     bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
 
+    # Rescale and align bounding boxes
     full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
     full_pred['boxes'] = align_boxes(full_pred, size_elements, class_dict)
 
     # Add diagram elements for each pool
     for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
 
         pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[pool_index])
 
         position = find_position(pool_index, full_pred['BPMN_id'])
         if position >= len(full_pred['boxes']):
             print("Problem with the index")
             continue
 
         add_diagram_elements(bpmnplane, pool_id, min_x, min_y, pool_width, pool_height)
 
     # Create BPMN elements for each pool
     for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
         create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
 
     reparsed = minidom.parseString(rough_string)
     pretty_xml_as_string = reparsed.toprettyxml(indent="  ")
 
+    # Restore the original box positions
     full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
     full_pred['boxes'] = old_boxes
 
 # Function that creates a single pool with all elements
 def create_big_pool(full_pred, text_mapping, size_elements, marge=50):
+    """
+    Create a single pool containing all elements if no pools or lanes are detected.
+
+    Args:
+        full_pred (dict): Dictionary containing full prediction results.
+        text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
+        size_elements (dict): Dictionary containing element sizes.
+        marge (int, optional): Margin to add around the pool. Defaults to 50.
+
+    Returns:
+        tuple: Updated full_pred and text_mapping.
+    """
     new_pool_index = 'pool_1'
     size_elements = get_size_elements(st.session_state.size_scale)
     elements_pool = list(range(len(full_pred['boxes'])))
+    min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'], full_pred['labels'], elements_pool, size_elements)
     box = [min_x - marge, min_y - marge//2, max_x + marge, max_y + marge//2]
     full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
     full_pred['pool_dict'][new_pool_index] = elements_pool
 
 # Function that gives the size of the elements
 def get_size_elements(size_scale=1):
+    """
+    Get the sizes of BPMN elements based on the scaling factor.
+
+    Args:
+        size_scale (float, optional): Scaling factor for element sizes. Defaults to 1.
+
+    Returns:
+        dict: Dictionary containing element sizes.
+    """
     size_elements = {
+        'event': (size_scale * 43.2, size_scale * 43.2),
+        'task': (size_scale * 120, size_scale * 96),
+        'message': (size_scale * 43.2, size_scale * 43.2),
+        'messageEvent': (size_scale * 43.2, size_scale * 43.2),
+        'exclusiveGateway': (size_scale * 60, size_scale * 60),
+        'parallelGateway': (size_scale * 60, size_scale * 60),
+        'dataObject': (size_scale * 48, size_scale * 72),
+        'dataStore': (size_scale * 72, size_scale * 72),
+        'subProcess': (size_scale * 144, size_scale * 108),
+        'eventBasedGateway': (size_scale * 60, size_scale * 60),
+        'timerEvent': (size_scale * 48, size_scale * 48),
     }
     return size_elements
 
 def rescale(scale, boxes):
+    """
+    Rescale the bounding boxes by a given scaling factor.
+
+    Args:
+        scale (float): Scaling factor.
+        boxes (list): List of bounding boxes.
+
+    Returns:
+        list: Rescaled bounding boxes.
+    """
     for i in range(len(boxes)):
+        boxes[i] = [boxes[i][0] * scale,
+                    boxes[i][1] * scale,
+                    boxes[i][2] * scale,
+                    boxes[i][3] * scale]
     return boxes
 
+# Function to create the unique BPMN_id
+def create_BPMN_id(labels, pool_dict):
+    """
+    Create unique BPMN IDs for each element based on their labels.
+
+    Args:
+        labels (list): List of labels for each element.
+        pool_dict (dict): Dictionary containing pool indices and their elements.
+
+    Returns:
+        tuple: List of BPMN IDs and updated pool dictionary.
+    """
+    BPMN_id = [class_dict[labels[i]] for i in range(len(labels))]
 
     data_counter = 1
 
         else:
             BPMN_id[idx] = f'{key}_{enums[key]}'
             enums[key] += 1
+
     # Update the pool_dict keys with their corresponding BPMN_id values
     updated_pool_dict = {}
     for key, value in pool_dict.items():
 
     return BPMN_id, updated_pool_dict
 
 def add_diagram_elements(parent, element_id, x, y, width, height):
+    """
+    Utility to add BPMN diagram notation for elements.
+
+    Args:
+        parent (Element): The parent XML element.
+        element_id (str): The ID of the BPMN element.
+        x (float): The x-coordinate of the element.
+        y (float): The y-coordinate of the element.
+        width (float): The width of the element.
+        height (float): The height of the element.
+    """
     shape = ET.SubElement(parent, 'bpmndi:BPMNShape', attrib={
         'bpmnElement': element_id,
         'id': element_id + '_di'
 
     })
 
 def add_diagram_edge(parent, element_id, waypoints):
+    """
+    Utility to add BPMN diagram notation for sequence flows.
+
+    Args:
+        parent (Element): The parent XML element.
+        element_id (str): The ID of the BPMN element.
+        waypoints (list): List of waypoints for the sequence flow.
+    """
     edge = ET.SubElement(parent, 'bpmndi:BPMNEdge', attrib={
         'bpmnElement': element_id,
         'id': element_id + '_di'
 
         'y': str(y)
     })
 
 def check_status(link, keep_elements):
+    """
+    Check the status of a link in terms of its position within the elements.
+
+    Args:
+        link (tuple): A tuple representing the start and end of the link.
+        keep_elements (list): List of elements to keep.
+
+    Returns:
+        str: Status of the link ('middle', 'start', or 'end').
+    """
     if link[0] in keep_elements and link[1] in keep_elements:
         return 'middle'
     elif link[0] is None and link[1] in keep_elements:
 
         return 'end'
     else:
         return 'middle'
+
 def check_data_association(i, links, labels, keep_elements):
+    """
+    Check data associations for an element.
+
+    Args:
+        i (int): Index of the current element.
+        links (list): List of links between elements.
+        labels (list): List of labels for each element.
+        keep_elements (list): List of elements to keep.
+
+    Returns:
+        tuple: Status and indices of data associations.
+    """
     status, links_idx = [], []
+    for j, (k, l) in enumerate(links):
         if labels[j] == list(class_dict.values()).index('dataAssociation'):
+            if k == i:
                 status.append('output')
                 links_idx.append(j)
+            elif l == i:
                 status.append('input')
                 links_idx.append(j)
 
     return status, links_idx
 
+def create_data_Association(bpmn, data, size, element_id, current_idx, source_id, target_id):
+    """
+    Create a data association in the BPMN diagram.
+
+    Args:
+        bpmn (Element): The parent XML element.
+        data (dict): Dictionary containing prediction results.
+        size (dict): Dictionary containing element sizes.
+        element_id (str): The ID of the BPMN element.
+        current_idx (int): Index of the current element.
+        source_id (str): The source element ID.
+        target_id (str): The target element ID.
+    """
     waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
     if waypoints is not None:
         add_diagram_edge(bpmn, element_id, waypoints)
+
 def check_eventBasedGateway(i, links, labels):
+    """
+    Check event-based gateway for an element.
+
+    Args:
+        i (int): Index of the current element.
+        links (list): List of links between elements.
+        labels (list): List of labels for each element.
+
+    Returns:
+        tuple: Status and indices of event-based gateway.
+    """
     status, links_idx = [], []
+    for j, (k, l) in enumerate(links):
         if labels[j] == list(class_dict.values()).index('sequenceFlow'):
+            if k == i:
                 status.append('output')
                 links_idx.append(j)
+            elif l == i:
                 status.append('input')
                 links_idx.append(j)
 
     return status, links_idx
+
 # Function to dynamically create and layout BPMN elements
 def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
+    """
+    Dynamically create and layout BPMN elements.
+
+    Args:
+        process (Element): The BPMN process element.
+        bpmnplane (Element): The BPMN plane element.
+        text_mapping (dict): Dictionary mapping BPMN IDs to text labels.
+        definitions (Element): The BPMN definitions element.
+        size (dict): Dictionary containing element sizes.
+        data (dict): Dictionary containing prediction results.
+        keep_elements (list): List of elements to keep.
+    """
     elements = data['BPMN_id']
     positions = data['boxes']
     links = data['links']
 
                     sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway_{link_idx}_{gateway_name.split("_")[1]}')
                     create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, element_id, gateway_name)
 
         add_diagram_elements(bpmnplane, element_id, x, y, size['eventBasedGateway'][0], size['eventBasedGateway'][1])
 
     # Data Object
 
         add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
 
 def calculate_pool_bounds(boxes, labels, keep_elements, size=None, class_dict=None):
+    """
+    Calculate the bounding box for a pool.
+
+    Args:
+        boxes (list): List of bounding boxes.
+        labels (list): List of labels for each element.
+        keep_elements (list): List of elements to keep.
+        size (dict, optional): Dictionary containing element sizes. Defaults to None.
+        class_dict (dict, optional): Dictionary mapping class indices to class names. Defaults to None.
+
+    Returns:
+        tuple: Minimum and maximum x and y coordinates of the pool.
+    """
     min_x, min_y = float('inf'), float('inf')
     max_x, max_y = float('-inf'), float('-inf')
 
     return min_x, min_y, max_x, max_y
 
 def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
+    """
+    Calculate waypoints for connecting elements within a pool.
+
+    Args:
+        idx (int): Index of the current element.
+        data (dict): Dictionary containing prediction results.
+        size (dict): Dictionary containing element sizes.
+        source_idx (int): Index of the source element.
+        target_idx (int): Index of the target element.
+        source_element (str): Source element type.
+        target_element (str): Target element type.
+
+    Returns:
+        list: List of waypoints for the connection.
+    """
     # Get the bounding boxes of the source and target elements
     source_box = data['boxes'][source_idx]
     target_box = data['boxes'][target_idx]
 
         waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
 
     return waypoints
 def add_curve(waypoints, pos_source, pos_target, threshold=30):
     """
     Add a single curve to the sequence flow by introducing a control point.
     The control point is added at an offset from the midpoint of the original waypoints.
+
+    Args:
+        waypoints (list): List of waypoints representing the path.
+        pos_source (str): Position of the source element ('left', 'right', 'top', 'bottom').
+        pos_target (str): Position of the target element ('left', 'right', 'top', 'bottom').
+        threshold (int, optional): Minimum distance to consider for adding a curve. Defaults to 30.
+
+    Returns:
+        list: List of waypoints with the added control point if applicable.
     """
     if len(waypoints) < 2:
         return waypoints
 
     if abs(start_x - end_x) < threshold or abs(start_y - end_y) < threshold:
         return waypoints
 
+    # Calculate the control point based on source and target positions
     if pos_source in pos_horizontal and pos_target in pos_horizontal:
         control_point = None
     elif pos_source in pos_vertical and pos_target in pos_vertical:
 
         control_point = (start_x, end_y)
     else:
         control_point = None
 
     # Create the curved path
     if control_point is not None:
 
     return curved_waypoints
 
 def calculate_waypoints(data, size, current_idx, source_id, target_id):
+    """
+    Calculate waypoints for connecting two elements in the diagram.
+
+    Args:
+        data (dict): Data containing diagram information.
+        size (dict): Dictionary of element sizes.
+        current_idx (int): Index of the current element.
+        source_id (str): ID of the source element.
+        target_id (str): ID of the target element.
+
+    Returns:
+        list: List of waypoints for the connection.
+    """
     best_points = data['best_points'][current_idx]
     pos_source = best_points[0]
     pos_target = best_points[1]
 
     if source_idx is None or target_idx is None:
         warning()
         return None
 
     name_source = source_id.split('_')[0]
     name_target = target_id.split('_')[0]
 
         warning()
         return [(source_x, source_y), (target_x, target_y)]
 
+    # Adjust the source coordinates based on its position
     if pos_source == 'left':
         source_x = source_x
         source_y += size[name_source][1] / 2
 
         source_x += size[name_source][0] / 2
         source_y += size[name_source][1]
 
+    # Adjust the target coordinates based on its position
     if pos_target == 'left':
         target_x = target_x
         target_y += size[name_target][1] / 2
 
     return curved_waypoints
 
 def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
+    """
+    Create a BPMN flow element (sequence flow or message flow) and add it to the BPMN diagram.
+
+    Args:
+        bpmn (ET.Element): The BPMN diagram element.
+        text_mapping (dict): Dictionary mapping element IDs to their text labels.
+        idx (int): Index of the current element.
+        size (dict): Dictionary of element sizes.
+        data (dict): Data containing diagram information.
+        parent (ET.Element): The parent element to which the flow element is added.
+        message (bool, optional): Whether the flow is a message flow. Defaults to False.
+    """
     source_idx, target_idx = data['links'][idx]
 
     if source_idx is None or target_idx is None:
 
         return
     element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
     add_diagram_edge(bpmn, element_id, waypoints)
 
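The add_curve() docstring above describes bending an orthogonal connection by inserting a single control point at the corner, for example (end_x, start_y). A hedged sketch of how such a control point can be expanded into a smooth curve follows; quadratic_curve is an illustrative helper only, and the repository's own curve construction (the part elided in the hunk) is not reproduced here.

def quadratic_curve(start, control, end, steps=8):
    # Sample a quadratic Bezier curve defined by start, control, and end points.
    points = []
    for i in range(steps + 1):
        t = i / steps
        x = (1 - t) ** 2 * start[0] + 2 * (1 - t) * t * control[0] + t ** 2 * end[0]
        y = (1 - t) ** 2 * start[1] + 2 * (1 - t) * t * control[1] + t ** 2 * end[1]
        points.append((round(x, 1), round(y, 1)))
    return points

waypoints = quadratic_curve((0, 0), (100, 0), (100, 80))
print(waypoints[0], waypoints[len(waypoints) // 2], waypoints[-1])
# (0.0, 0.0) at the source, a rounded corner such as (75.0, 20.0) in the middle, (100.0, 80.0) at the target.
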
modules/train.py
CHANGED
@@ -15,8 +15,6 @@ from tqdm import tqdm
 from modules.utils import write_results
 
 
-
-
 def get_arrow_model(num_classes, num_keypoints=2):
     """
     Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
@@ -27,14 +25,6 @@ def get_arrow_model(num_classes, num_keypoints=2):
 
     Returns:
     - model (torch.nn.Module): The modified Keypoint R-CNN model.
-
-    Steps:
-    1. Load a pre-trained Keypoint R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN).
-       The model is initially configured for the COCO dataset, which includes various object classes and keypoints.
-    2. Replace the box predictor to adjust the number of output classes. The box predictor is responsible for
-       classifying detected regions and predicting their bounding boxes.
-    3. Replace the keypoint predictor to adjust the number of keypoints the model predicts for each object.
-       This is necessary to tailor the model to specific tasks that may have different keypoint structures.
     """
     # Load a model pre-trained on COCO, initialized without pre-trained weights
     model = keypointrcnn_resnet50_fpn(weights=None)
@@ -72,44 +62,60 @@ def get_faster_rcnn_model(num_classes):
 
     return model
 
-def prepare_model(dict,opti,learning_rate=
-
-
-        model = get_faster_rcnn_model(len(dict))
-    elif model_type == 'arrow':
-        model = get_arrow_model(len(dict),2)
 
-
-
-    #
-
-        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
-    else:
-        print('Optimizer not found')
 
-
-from torch.optim import AdamW
-import time
-from modules.train import write_results
 
-import torch
-import numpy as np
-from tqdm import tqdm
 
 def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
     model.train()  # Set the model to evaluation mode
     total_loss = 0
 
@@ -174,12 +180,12 @@ def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=Fal
     avg_loss_keypoints = np.mean(loss_keypoints_list)
 
     if print_losses:
-
-
-
-
-
-
 
     return avg_loss
 
@@ -188,206 +194,225 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
                    optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
                    parameters=None, blur_prob=0.02,
                    score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
-                   information_training='training', start_epoch=0, loss_config=None, model_type
                    eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
 
-            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
-
-            optimizer.zero_grad()
-
-            loss_dict = model(images, targets)
-            # Inside the training loop where losses are calculated:
-            losses = 0
-            if loss_config is not None:
-                for key, loss in loss_dict.items():
-                    if loss_config.get(key, False):
-                        if key == 'loss_classifier':
-                            loss *= 3
-                        losses += loss
-            else:
-                losses = sum(loss for key, loss in loss_dict.items())
-
-            # Collect individual losses
-            if loss_dict['loss_classifier']:
-                loss_classifier_list.append(loss_dict['loss_classifier'].item())
-            else:
-                loss_classifier_list.append(0)
-
-            if loss_dict['loss_box_reg']:
-                loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
-            else:
-                loss_box_reg_list.append(0)
-
-            if loss_dict['loss_objectness']:
-                loss_objectness_list.append(loss_dict['loss_objectness'].item())
-            else:
-                loss_objectness_list.append(0)
-
-            if loss_dict['loss_rpn_box_reg']:
-                loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
-            else:
-                loss_rpn_box_reg_list.append(0)
-
-            if 'loss_keypoint' in loss_dict:
-                loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
-            else:
-                loss_keypoints_list.append(0)
-
-            losses.backward()
-            optimizer.step()
-
-            total_loss += losses.item()
-
-            # Update the description with the current loss
-            progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')
-
-        # Calculate average loss
-        avg_loss = total_loss / len(data_loader)
-
-        epoch_avg_losses.append(avg_loss)
-        epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
-        epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
-        epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
-        epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
-        epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
 
-            best_model_state = copy.deepcopy(model.state_dict())
-
-        if epoch>0 and f1_score>early_stop_f1_score:
-            same+=1
-
-        epoch_precision.append(precision)
-        epoch_recall.append(recall)
-        epoch_f1_score.append(f1_score)
-        epoch_test_loss.append(avg_test_loss)
-
-        name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
-        metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
-
-        if same >=1 :
-            torch.save(best_model_state, './models/'+ name_model +'.pth')
-            write_results(name_model,metrics_list,start_epoch)
-            break
-
-        if (epoch+1+start_epoch) % 5 == 0:
-            torch.save(best_model_state, './models/'+ name_model +'.pth')
-            model.load_state_dict(best_model_state)
-            write_results(name_model,metrics_list,start_epoch)
-
-
-
 from modules.utils import write_results
 
 
 def get_arrow_model(num_classes, num_keypoints=2):
     """
     Configures and returns a modified Keypoint R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes and keypoints.
 
     Returns:
     - model (torch.nn.Module): The modified Keypoint R-CNN model.
     """
     # Load a model pre-trained on COCO, initialized without pre-trained weights
     model = keypointrcnn_resnet50_fpn(weights=None)
 
     return model
 
+def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
+    """
+    Prepares the model and optimizer for training.
+
+    Parameters:
+    - dict (dict): Dictionary of classes.
+    - opti (str): Optimizer type ('SGD' or 'Adam').
+    - learning_rate (float): Learning rate for the optimizer.
+    - model_to_load (str, optional): Name of the model to load.
+    - model_type (str): Type of model to prepare ('object' or 'arrow').
+
+    Returns:
+    - model (torch.nn.Module): The prepared model.
+    - optimizer (torch.optim.Optimizer): The configured optimizer.
+    - device (torch.device): The device (CPU or CUDA) on which to perform training.
+    """
+    # Adjusted to pass the class_dict directly
+    if model_type == 'object':
+        model = get_faster_rcnn_model(len(dict))
+    elif model_type == 'arrow':
+        model = get_arrow_model(len(dict), 2)
 
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    # Load the model weights
+    if model_to_load:
+        model.load_state_dict(torch.load('./models/' + model_to_load + '.pth', map_location=device))
+        print(f"Model '{model_to_load}' loaded")
 
+    model.to(device)
 
+    if opti == 'SGD':
+        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
+    elif opti == 'Adam':
+        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))
+    else:
+        print('Optimizer not found')
 
+    return model, optimizer, device
 
 
 def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
+    """
+    Evaluate the loss of the model on a validation dataset.
+
+    Parameters:
+    - model (torch.nn.Module): The model to evaluate.
+    - data_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+    - device (torch.device): Device to perform evaluation on.
+    - loss_config (dict, optional): Configuration specifying which losses to use.
+    - print_losses (bool): Whether to print individual loss components.
+
+    Returns:
+    - float: Average loss over the validation dataset.
+    """
     model.train()  # Set the model to evaluation mode
     total_loss = 0
 
     avg_loss_keypoints = np.mean(loss_keypoints_list)
 
     if print_losses:
+        print(f"Average Loss: {avg_loss:.4f}")
+        print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
+        print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
+        print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
+        print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
+        print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")
 
     return avg_loss
 
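For orientation, a hedged usage sketch of the prepare_model() helper shown above; the class dictionary below is a toy subset and the import path simply follows the module layout of this diff, it is not a verified snippet from the repository.

from modules.train import prepare_model  # assumed import path

object_classes = {0: 'background', 1: 'task', 2: 'exclusiveGateway'}  # toy subset of the real class dict
model, optimizer, device = prepare_model(object_classes, 'Adam',
                                         learning_rate=0.0003,
                                         model_to_load=None,  # or the name of a saved './models/<name>.pth'
                                         model_type='object')
# 'Adam' selects the AdamW optimizer branch; 'SGD' selects plain SGD with momentum.
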
                    optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
                    parameters=None, blur_prob=0.02,
                    score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
+                   information_training='training', start_epoch=0, loss_config=None, model_type='object',
                    eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
+    """
+    Train the model over a specified number of epochs.
+
+    Parameters:
+    - num_epochs (int): Number of epochs to train for.
+    - model (torch.nn.Module): Model to train.
+    - data_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
+    - subset_test_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+    - optimizer (torch.optim.Optimizer): Optimizer to use for training.
+    - model_to_load (str, optional): Name of the model to load.
+    - change_learning_rate (int): Epoch interval to change the learning rate.
+    - start_key (int): Epoch to start training keypoints.
+    - parameters (dict, optional): Additional training parameters.
+    - blur_prob (float): Probability of applying blur augmentation.
+    - score_threshold (float): Score threshold for evaluation.
+    - iou_threshold (float): IoU threshold for evaluation.
+    - early_stop_f1_score (float): F1 score threshold for early stopping.
+    - information_training (str): Information about the training.
+    - start_epoch (int): Starting epoch number.
+    - loss_config (dict, optional): Configuration specifying which losses to use.
+    - model_type (str): Type of model ('object' or 'arrow').
+    - eval_metric (str): Evaluation metric ('f1_score', 'precision', 'recall', or 'loss').
+    - device (torch.device): Device to perform training on.
+
+    Returns:
+    - model (torch.nn.Module): Trained model.
+    """
+    model.train()
+
+    if loss_config is None:
+        print('No loss config found, all losses will be used.')
+    else:
+        # Print the list of the losses that will be used
+        print('The following losses will be used: ', end='')
+        for key, value in loss_config.items():
+            if value:
+                print(key, end=", ")
+        print()
+
+    # Initialize lists to store epoch-wise average losses
+    epoch_avg_losses = []
+    epoch_avg_loss_classifier = []
+    epoch_avg_loss_box_reg = []
+    epoch_avg_loss_objectness = []
+    epoch_avg_loss_rpn_box_reg = []
+    epoch_avg_loss_keypoints = []
+    epoch_precision = []
+    epoch_recall = []
+    epoch_f1_score = []
+    epoch_test_loss = []
+
+    start_tot = time.time()
+    best_metrics = -1000
+    best_epoch = 0
+    best_model_state = None
+    same = 0
+    learning_rate = optimizer.param_groups[0]['lr']
+    bad_test_loss = 0
+    previous_test_loss = 1000
+
+    if parameters is not None:
+        batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
+
+    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
+    if parameters is not None:
+        print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
+
+    for epoch in range(num_epochs):
+        if (epoch > 0 and (epoch) % change_learning_rate == 0) or bad_test_loss >= 3:
+            learning_rate = 0.7 * learning_rate
+            optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
+            if best_model_state is not None:
+                model.load_state_dict(best_model_state)
+            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
+            bad_test_loss = 0
+        if epoch > 0 and (epoch) == start_key:
+            print("Now it's training Keypoints also")
+            loss_config['loss_keypoint'] = True
+            for name, param in model.named_parameters():
+                if 'keypoint' in name:
+                    param.requires_grad = True
+
+        model.train()
+        start = time.time()
+        total_loss = 0
+
+        # Initialize lists to keep track of individual losses
+        loss_classifier_list = []
+        loss_box_reg_list = []
+        loss_objectness_list = []
+        loss_rpn_box_reg_list = []
+        loss_keypoints_list = []
+
+        # Create a tqdm progress bar
+        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch + 1 + start_epoch}')
+
+        for images, targets_im in progress_bar:
+            images = [image.to(device) for image in images]
+            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
 
+            optimizer.zero_grad()
 
+            loss_dict = model(images, targets)
+            # Inside the training loop where losses are calculated:
+            losses = 0
+            if loss_config is not None:
+                for key, loss in loss_dict.items():
+                    if loss_config.get(key, False):
+                        if key == 'loss_classifier':
+                            loss *= 3
+                        losses += loss
+            else:
+                losses = sum(loss for key, loss in loss_dict.items())
+
+            # Collect individual losses
+            if loss_dict['loss_classifier']:
+                loss_classifier_list.append(loss_dict['loss_classifier'].item())
+            else:
+                loss_classifier_list.append(0)
+
+            if loss_dict['loss_box_reg']:
+                loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
+            else:
+                loss_box_reg_list.append(0)
+
+            if loss_dict['loss_objectness']:
+                loss_objectness_list.append(loss_dict['loss_objectness'].item())
+            else:
+                loss_objectness_list.append(0)
 
+            if loss_dict['loss_rpn_box_reg']:
+                loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
+            else:
+                loss_rpn_box_reg_list.append(0)
+
+            if 'loss_keypoint' in loss_dict:
+                loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
+            else:
+                loss_keypoints_list.append(0)
 
+            losses.backward()
+            optimizer.step()
 
+            total_loss += losses.item()
+
+            # Update the description with the current loss
+            progress_bar.set_description(f'Epoch {epoch + 1 + start_epoch}, Loss: {losses.item():.4f}')
+
+        # Calculate average loss
+        avg_loss = total_loss / len(data_loader)
+
+        epoch_avg_losses.append(avg_loss)
+        epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
+        epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
+        epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
+        epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
+        epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
 
+        # Evaluate the model on the test set
+        if eval_metric == 'loss':
+            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
+            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
+            print(f"Epoch {epoch + 1 + start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
+        else:
+            avg_test_loss = 0
+            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
+            print(f"Epoch {epoch + 1 + start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
+            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
+            print(f"Epoch {epoch + 1 + start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
+
+        print(f"Time: {time.time() - start:.2f} [s]")
+
+        if eval_metric == 'f1_score':
+            metric_used = f1_score
+        elif eval_metric == 'precision':
+            metric_used = precision
+        elif eval_metric == 'recall':
+            metric_used = recall
+        else:
+            metric_used = -avg_test_loss
+
+        # Check if this epoch's model has the lowest average loss
+        if metric_used > best_metrics:
+            best_metrics = metric_used
+            best_epoch = epoch + 1 + start_epoch
+            best_model_state = copy.deepcopy(model.state_dict())
+
+        if epoch > 0 and f1_score > early_stop_f1_score:
+            same += 1
+
+        epoch_precision.append(precision)
+        epoch_recall.append(recall)
+        epoch_f1_score.append(f1_score)
+        epoch_test_loss.append(avg_test_loss)
+
+        name_model = f"model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}"
+        metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints, epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]
+
+        if same >= 1:
+            torch.save(best_model_state, './models/' + name_model + '.pth')
+            write_results(name_model, metrics_list, start_epoch)
+            break
+
+        if (epoch + 1 + start_epoch) % 5 == 0:
+            torch.save(best_model_state, './models/' + name_model + '.pth')
+            model.load_state_dict(best_model_state)
+            write_results(name_model, metrics_list, start_epoch)
+
+        if avg_test_loss > previous_test_loss:
+            bad_test_loss += 1
+        previous_test_loss = avg_test_loss
+
+    print(f"\n Total time: {(time.time() - start_tot) / 60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
+    if best_model_state:
+        torch.save(best_model_state, './models/' + name_model + '.pth')
+        model.load_state_dict(best_model_state)
+        write_results(name_model, metrics_list, start_epoch)
+    print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}")
+
+    return model
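A hedged end-to-end sketch of how prepare_model() and training_model() could be wired together, matching the parameter unpacking order used in the docstring above. The data loaders are placeholders you must supply yourself, and the keys of `parameters` follow the unpacking order shown in the code, so this is an illustration rather than a verified script from the repository.

from modules.train import prepare_model, training_model
from modules.utils import class_dict  # assumed to be the full class mapping from modules/utils.py

model, optimizer, device = prepare_model(class_dict, 'Adam', model_type='object')

parameters = {  # order matters: it is unpacked positionally via parameters.values()
    'batch_size': 4, 'crop_prob': 0.2, 'rotate_90_proba': 0.1,
    'h_flip_prob': 0.3, 'v_flip_prob': 0.0, 'max_rotate_deg': 10,
    'rotate_proba': 0.2, 'keep_ratio': True,
}
loss_config = {'loss_classifier': True, 'loss_box_reg': True,
               'loss_objectness': True, 'loss_rpn_box_reg': True,
               'loss_keypoint': False}

# train_loader and val_loader are assumed torch DataLoaders of (image, target) pairs.
model = training_model(50, model, train_loader, val_loader, optimizer,
                       parameters=parameters, loss_config=loss_config,
                       eval_metric='f1_score', device=device)
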
modules/utils.py
CHANGED
@@ -1,59 +1,11 @@
-from torchvision.models.detection import keypointrcnn_resnet50_fpn
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
-from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
-import random
 import torch
-from torch.utils.data import Dataset
 import torchvision.transforms.functional as F
 import numpy as np
-from torch.utils.data.dataloader import default_collate
 import cv2
 import matplotlib.pyplot as plt
-from torch.utils.data import DataLoader, Subset, ConcatDataset
 import streamlit as st
 
-
-"""object_dict = {
-    0: 'background',
-    1: 'task',
-    2: 'exclusiveGateway',
-    3: 'eventBasedGateway',
-    4: 'event',
-    5: 'messageEvent',
-    6: 'timerEvent',
-    7: 'dataObject',
-    8: 'dataStore',
-    9: 'pool',
-    10: 'lane',
-}
-
-
-arrow_dict = {
-    0: 'background',
-    1: 'sequenceFlow',
-    2: 'dataAssociation',
-    3: 'messageFlow',
-}
-
-class_dict = {
-    0: 'background',
-    1: 'task',
-    2: 'exclusiveGateway',
-    3: 'eventBasedGateway',
-    4: 'event',
-    5: 'messageEvent',
-    6: 'timerEvent',
-    7: 'dataObject',
-    8: 'dataStore',
-    9: 'pool',
-    10: 'lane',
-    11: 'sequenceFlow',
-    12: 'dataAssociation',
-    13: 'messageFlow',
-}"""
-
-
 object_dict = {
     0: 'background',
     1: 'task',
@@ -96,7 +48,6 @@ class_dict = {
     15: 'messageFlow',
 }
 
-
 def is_inside(box1, box2):
     """Check if the center of box1 is inside box2."""
     x_center = (box1[0] + box1[2]) / 2
@@ -107,51 +58,31 @@ def is_vertical(box):
     """Determine if the text in the bounding box is vertically aligned."""
     width = box[2] - box[0]
     height = box[3] - box[1]
-    return (height > 2*width)
 
 def rescale_boxes(scale, boxes):
     for i in range(len(boxes)):
-
-                    boxes[i][1]*scale,
-                    boxes[i][2]*scale,
-                    boxes[i][3]*scale]
     return boxes
 
 def iou(box1, box2):
-
     inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
     inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
-
-    # Compute the union of the two bounding boxes
     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
     box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
     union_area = box1_area + box2_area - inter_area
-
     return inter_area / union_area
 
 def proportion_inside(box1, box2):
-
     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
     box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
-
-    # Determine the bigger and smaller boxes
-    if box1_area > box2_area:
-        big_box = box1
-        small_box = box2
-    else:
-        big_box = box2
-        small_box = box1
-
-    # Calculate the intersection of the two bounding boxes
     inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
     inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
-
-    # Calculate the proportion of the smaller box inside the bigger box
-    if (small_box[2] - small_box[0]) * (small_box[3] - small_box[1]) == 0:
-        return 0
     proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
-
-    # Ensure the proportion is at most 100%
     return min(proportion, 1.0)
 
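A standalone numeric check of the overlap measure used by iou() above (a sketch with a hypothetical helper name, not an import of the repository's module): two 10x10 boxes offset by 5 pixels overlap in a 5x5 square, so the IoU is 25 / (100 + 100 - 25), roughly 0.143.

def iou_boxes(box1, box2):
    # Intersection-over-union of two axis-aligned boxes given as [x1, y1, x2, y2].
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    return inter / (area1 + area2 - inter)

print(round(iou_boxes([0, 0, 10, 10], [5, 5, 15, 15]), 3))  # 0.143
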
157 |
def resize_boxes(boxes, original_size, target_size):
|
@@ -168,20 +99,15 @@ def resize_boxes(boxes, original_size, target_size):
168 |       """
169 |       orig_width, orig_height = original_size
170 |       target_width, target_height = target_size
171 | -
172 | -     # Calculate the ratios for width and height
173 |       width_ratio = target_width / orig_width
174 |       height_ratio = target_height / orig_height
175 | -
176 | -     # Apply the ratios to the bounding boxes
177 |       boxes[:, 0] *= width_ratio
178 |       boxes[:, 1] *= height_ratio
179 |       boxes[:, 2] *= width_ratio
180 |       boxes[:, 3] *= height_ratio
181 | -
182 |       return boxes
183 |
184 | - def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple):
185 |       """
186 |       Resize keypoints based on the original and target dimensions of an image.
187 |
@@ -192,40 +118,38 @@ def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: t
192 |
193 |       Returns:
194 |       - np.ndarray: The resized keypoints.
195 | -
196 | -     Explanation:
197 | -     The function calculates the ratio of the target dimensions to the original dimensions.
198 | -     It then applies these ratios to the x and y coordinates of each keypoint to scale them
199 | -     appropriately to the target image size.
200 |       """
201 | -
202 |       orig_width, orig_height = original_size
203 |       target_width, target_height = target_size
204 | -
205 | -     # Calculate the ratios for width and height scaling
206 |       width_ratio = target_width / orig_width
207 |       height_ratio = target_height / orig_height
208 | -
209 | -
210 | -     keypoints[:, 0] *= width_ratio  # Scale x coordinates
211 | -     keypoints[:, 1] *= height_ratio  # Scale y coordinates
212 | -
213 |       return keypoints
214 |
215 | -
216 | -
217 | -
218 |       for i in range(len(metrics_list[0])):
219 | -
220 | -
221 |
222 |   def find_other_keypoint(idx, keypoints, boxes):
223 |       box = boxes[idx]
224 | -     key1,key2 = keypoints[idx]
225 |       x1, y1, x2, y2 = box
226 |       center = ((x1 + x2) // 2, (y1 + y2) // 2)
227 |       average_keypoint = (key1 + key2) // 2
228 | -     #find the opposite keypoint to the center
229 |       if average_keypoint[0] < center[0]:
230 |           x = center[0] + abs(center[0] - average_keypoint[0])
231 |       else:
@@ -235,7 +159,6 @@ def find_other_keypoint(idx, keypoints, boxes):
235 |       else:
236 |           y = center[1] - abs(center[1] - average_keypoint[1])
237 |       return x, y, average_keypoint[0], average_keypoint[1]
238 | -
239 |
240 |   def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
241 |       """
@@ -251,47 +174,28 @@ def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
251 |       Returns:
252 |       - tuple: Filtered boxes, scores, labels, and keypoints.
253 |       """
254 | -     # Calculate the area of each bounding box to use in IoU calculation.
255 |       areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
256 | -
257 | -     # Sort the indices of the boxes based on their scores in descending order.
258 |       order = scores.argsort()[::-1]
259 | -
260 | -     keep = []  # List to store indices of boxes to keep.
261 | -
262 |       while order.size > 0:
263 | -         # Take the first index (highest score) from the sorted list.
264 |           i = order[0]
265 | -         keep.append(i)
266 | -
267 | -         # Compute the coordinates of the intersection rectangle.
268 |           xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
269 |           yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
270 |           xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
271 |           yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
272 | -
273 | -         # Compute the area of the intersection rectangle.
274 |           w = np.maximum(0.0, xx2 - xx1)
275 |           h = np.maximum(0.0, yy2 - yy1)
276 |           inter = w * h
277 | -
278 | -         # Calculate IoU and find boxes with IoU less than the threshold to keep.
279 |           iou = inter / (areas[i] + areas[order[1:]] - inter)
280 |           inds = np.where(iou <= iou_threshold)[0]
281 | -
282 | -         # Update the list of box indices to consider in the next iteration.
283 | -         order = order[inds + 1]  # Skip the first element since it's already included in 'keep'.
284 | -
285 | -     # Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return.
286 |       boxes = boxes[keep]
287 |       scores = scores[keep]
288 |       labels = labels[keep]
289 |       keypoints = keypoints[keep]
290 | -
291 |       return boxes, scores, labels, keypoints
292 |
293 | -
294 | -
295 |   def draw_annotations(image,
296 |                        target=None,
297 |                        prediction=None,
@@ -312,7 +216,7 @@ def draw_annotations(image,
312 |                        only_print=None,
313 |                        axis=False,
314 |                        return_image=False,
315 | -                      new_size=(1333,800),
316 |                        resize=False):
317 |       """
318 |       Draws annotations on images including bounding boxes, keypoints, links, and text.
@@ -328,7 +232,7 @@ def draw_annotations(image,
328 |       - draw_boxes (bool): Flag to draw bounding boxes.
329 |       - draw_text (bool): Flag to draw text annotations.
330 |       - draw_links (bool): Flag to draw links between annotations.
331 | -     - draw_twins (bool): Flag to draw
332 |       - write_class (bool): Flag to write class names near the annotations.
333 |       - write_score (bool): Flag to write scores near the annotations.
334 |       - write_text (bool): Flag to write OCR recognized text.
@@ -345,137 +249,119 @@ def draw_annotations(image,
345 |       image_copy = image.copy()
346 |       scale = max(image.shape[0], image.shape[1]) / 1000
347 |
348 | -     #
349 | -     def draw(data,is_prediction=False):
350 | -         """ Helper function to draw annotations based on provided data. """
351 | -
352 |           for i in range(len(data['boxes'])):
353 |               if is_prediction:
354 | -                 box = data['boxes'][i].tolist()
355 | -                 x1, y1, x2, y2 = box
356 | -                 if resize:
357 | -                     x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
358 |                   score = data['scores'][i].item()
359 |                   if score < score_threshold:
360 |                       continue
361 | -             else:
362 | -                 box = data['boxes'][i].tolist()
363 | -                 x1, y1, x2, y2 = box
364 |               if draw_boxes:
365 |                   if only_print is not None:
366 |                       if data['labels'][i] != list(model_dict.values()).index(only_print):
367 |                           continue
368 | -                 cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0) if is_prediction else (0, 0, 0), int(2*scale))
369 |               if is_prediction and write_score:
370 | -                 cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)
371 |
372 |               if write_class and 'labels' in data:
373 |                   class_id = data['labels'][i].item()
374 | -                 cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
375 |
376 |               if write_idx:
377 | -                 cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)
378 | -
379 |
380 |           # Draw keypoints if available
381 |           if draw_keypoints and 'keypoints' in data:
382 |               if is_prediction and keypoints_correction:
383 |                   for idx, (key1, key2) in enumerate(data['keypoints']):
384 |                       if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
385 | -
386 | -
387 |                           continue
388 | -                     # Calculate the Euclidean distance between the two keypoints
389 |                       distance = np.linalg.norm(key1[:2] - key2[:2])
390 | -
391 |                       if distance < 5:
392 | -                         x_new,y_new, x,y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
393 | -                         data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1])
394 | -                         data['keypoints'][idx][1] = torch.tensor([x, y,1])
395 |                           print("keypoint has been changed")
396 |               for i in range(len(data['keypoints'])):
397 |                   kp = data['keypoints'][i]
398 |                   for j in range(kp.shape[0]):
399 | -                     if is_prediction and data['labels'][i]
400 |                           continue
401 |                       if is_prediction:
402 |                           score = data['scores'][i]
403 |                           if score < score_threshold:
404 |                               continue
405 | -                     x,y,v = np.array(kp[j])
406 |                       if resize:
407 | -                         x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
408 |                       if j == 0:
409 | -                         cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
410 |                       else:
411 | -                         cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
412 |
413 |       # Draw text predictions if available
414 | -     if (draw_text or write_text) and text_predictions is not None:
415 |           for i in range(len(text_predictions[0])):
416 |               x1, y1, x2, y2 = text_predictions[0][i]
417 |               text = text_predictions[1][i]
418 |               if resize:
419 | -                 x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
420 |               if draw_text:
421 | -                 cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
422 |               if write_text:
423 | -                 cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2)
424 | -
425 |       def draw_with_links(full_prediction):
426 | -
427 | -         #check if keypoints detected are the same
428 |           if draw_twins and full_prediction is not None:
429 | -
430 | -
431 | -             circle_radius = int(10 * scale)  # Circle radius scaled by image scale
432 | -
433 |               for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
434 |                   if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
435 | -
436 | -
437 |                       continue
438 | -                 # Calculate the Euclidean distance between the two keypoints
439 |                   distance = np.linalg.norm(key1[:2] - key2[:2])
440 |                   if distance < 10:
441 | -                     x_new,y_new, x,y = find_other_keypoint(idx,full_prediction)
442 |                       cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
443 | -                     cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)
444 |
445 | -
446 | -         if draw_links==True and full_prediction is not None:
447 |               for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
448 |                   if start_idx is None or end_idx is None:
449 |                       continue
450 |                   start_box = full_prediction['boxes'][start_idx]
451 |                   end_box = full_prediction['boxes'][end_idx]
452 |                   current_box = full_prediction['boxes'][i]
453 | -                 # Calculate the center of each bounding box
454 |                   start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
455 |                   end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
456 |                   current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
457 | -
458 | -                 cv2.line(image_copy, (int(
459 | -                 cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
460 |
461 | -                 i+=1
462 |
463 | -     # Draw GT annotations
464 |       if target is not None:
465 |           draw(target, is_prediction=False)
466 | -     # Draw predictions
467 |       if prediction is not None:
468 | -         #prediction = prediction[0]
469 |           draw(prediction, is_prediction=True)
470 | -     # Draw links with full predictions
471 |       if full_prediction is not None:
472 |           draw_with_links(full_prediction)
473 |
474 | -     # Display the image
475 |       image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
476 |       plt.figure(figsize=(12, 12))
477 |       plt.imshow(image_copy)
478 | -     if axis
479 |           plt.axis('off')
480 |       plt.show()
481 |
@@ -496,28 +382,24 @@ def find_closest_object(keypoint, boxes, labels):
496 |       closest_object_idx = None
497 |       best_point = None
498 |       min_distance = float('inf')
499 | -     # Iterate over each bounding box
500 |       for i, box in enumerate(boxes):
501 |           if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
502 |                            list(class_dict.values()).index('messageFlow'),
503 |                            list(class_dict.values()).index('dataAssociation'),
504 | -                          #list(class_dict.values()).index('pool'),
505 |                            list(class_dict.values()).index('lane')]:
506 |               continue
507 |           x1, y1, x2, y2 = box
508 |
509 | -         top = ((x1+x2)/2, y1)
510 | -         bottom = ((x1+x2)/2, y2)
511 | -         left = (x1, (y1+y2)/2)
512 | -         right = (x2, (y1+y2)/2)
513 | -         points = [left, top
514 |
515 | -         pos_dict = {0:'left', 1:'top', 2:'right', 3:'bottom'}
516 |
517 | -
518 | -         for pos, (point) in enumerate(points):
519 |               distance = np.linalg.norm(keypoint[:2] - point)
520 | -             # Update the closest object index if this object is closer
521 |               if distance < min_distance:
522 |                   min_distance = distance
523 |                   closest_object_idx = i
@@ -525,9 +407,10 @@ def find_closest_object(keypoint, boxes, labels):
525 |
526 |       return closest_object_idx, best_point
527 |
528 | -
529 |   def error(text='There is an error in the detection'):
530 |       st.error(text, icon="🚨")
531 |
532 |   def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'):
533 |       st.warning(text, icon="⚠️")
1 |   import torch
2 |   import torchvision.transforms.functional as F
3 |   import numpy as np
4 |   import cv2
5 |   import matplotlib.pyplot as plt
6 |   import streamlit as st
7 |
8 | + # Define dictionaries to map class indices to their corresponding names
9 |   object_dict = {
10 |       0: 'background',
11 |       1: 'task',
48 |       15: 'messageFlow',
49 |   }
50 |
51 |   def is_inside(box1, box2):
52 |       """Check if the center of box1 is inside box2."""
53 |       x_center = (box1[0] + box1[2]) / 2
58 |       """Determine if the text in the bounding box is vertically aligned."""
59 |       width = box[2] - box[0]
60 |       height = box[3] - box[1]
61 | +     return (height > 2 * width)
62 |
63 |   def rescale_boxes(scale, boxes):
64 | +     """Rescale the bounding boxes by a given scale factor."""
65 |       for i in range(len(boxes)):
66 | +         boxes[i] = [boxes[i][0] * scale, boxes[i][1] * scale, boxes[i][2] * scale, boxes[i][3] * scale]
67 |       return boxes
68 |
69 |   def iou(box1, box2):
70 | +     """Calculate the Intersection over Union (IoU) of two bounding boxes."""
71 |       inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
72 |       inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
73 |       box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
74 |       box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
75 |       union_area = box1_area + box2_area - inter_area
76 |       return inter_area / union_area
77 |
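To make the helper above concrete, a small usage sketch for iou (an illustration only, not part of the commit; boxes are assumed to be [x1, y1, x2, y2] pixel coordinates and the values are made up):

box_a = [0, 0, 100, 100]
box_b = [50, 50, 150, 150]
print(iou(box_a, box_b))  # intersection 50*50 = 2500, union 17500, so IoU is roughly 0.14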
78 |   def proportion_inside(box1, box2):
79 | +     """Calculate the proportion of the smaller box inside the larger box."""
80 |       box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
81 |       box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
82 | +     big_box, small_box = (box1, box2) if box1_area > box2_area else (box2, box1)
83 |       inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
84 |       inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
85 |       proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
86 |       return min(proportion, 1.0)
87 |
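A short usage sketch for proportion_inside (illustrative values, not from the repository; here half of the small box overlaps the large one):

small = [0, 0, 10, 10]
large = [5, -100, 200, 100]
print(proportion_inside(small, large))  # 0.5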
88 |   def resize_boxes(boxes, original_size, target_size):
99 |       """
100 |       orig_width, orig_height = original_size
101 |       target_width, target_height = target_size
102 |       width_ratio = target_width / orig_width
103 |       height_ratio = target_height / orig_height
104 |       boxes[:, 0] *= width_ratio
105 |       boxes[:, 1] *= height_ratio
106 |       boxes[:, 2] *= width_ratio
107 |       boxes[:, 3] *= height_ratio
108 |       return boxes
109 |
110 | + def resize_keypoints(keypoints, original_size, target_size):
111 |       """
112 |       Resize keypoints based on the original and target dimensions of an image.
113 |
118 |
119 |       Returns:
120 |       - np.ndarray: The resized keypoints.
121 |       """
122 |       orig_width, orig_height = original_size
123 |       target_width, target_height = target_size
124 |       width_ratio = target_width / orig_width
125 |       height_ratio = target_height / orig_height
126 | +     keypoints[:, 0] *= width_ratio
127 | +     keypoints[:, 1] *= height_ratio
128 |       return keypoints
129 |
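A minimal sketch of how the two resize helpers are meant to be used together (assuming NumPy float arrays; both functions scale the coordinates in place and return the same array):

import numpy as np

boxes = np.array([[10.0, 20.0, 110.0, 220.0]])
keypoints = np.array([[10.0, 20.0, 1.0]])
boxes = resize_boxes(boxes, (1000, 800), (500, 400))              # -> [[5., 10., 55., 110.]]
keypoints = resize_keypoints(keypoints, (1000, 800), (500, 400))  # -> [[5., 10., 1.]]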
130 | + def write_results(name_model, metrics_list, start_epoch):
131 | +     """Write training results to a text file."""
132 | +     with open('./results/' + name_model + '.txt', 'w') as f:
133 |           for i in range(len(metrics_list[0])):
134 | +             f.write(f"{i + 1 + start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n")
135 |
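For reference, a hedged sketch of the file this produces (assuming metrics_list holds ten equal-length lists and that the ./results/ directory already exists):

metrics_list = [[0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9], [1.0]]
write_results('demo_model', metrics_list, start_epoch=0)
# ./results/demo_model.txt then contains one comma-separated row per epoch, e.g.
# 1,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0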
136 |   def find_other_keypoint(idx, keypoints, boxes):
137 | +     """
138 | +     Find the opposite keypoint to the center of the box.
139 | +
140 | +     Parameters:
141 | +     - idx (int): The index of the box and keypoints.
142 | +     - keypoints (np.ndarray): The array of keypoints.
143 | +     - boxes (np.ndarray): The array of bounding boxes.
144 | +
145 | +     Returns:
146 | +     - tuple: The coordinates of the new keypoint and the average keypoint.
147 | +     """
148 |       box = boxes[idx]
149 | +     key1, key2 = keypoints[idx]
150 |       x1, y1, x2, y2 = box
151 |       center = ((x1 + x2) // 2, (y1 + y2) // 2)
152 |       average_keypoint = (key1 + key2) // 2
153 |       if average_keypoint[0] < center[0]:
154 |           x = center[0] + abs(center[0] - average_keypoint[0])
155 |       else:
159 |       else:
160 |           y = center[1] - abs(center[1] - average_keypoint[1])
161 |       return x, y, average_keypoint[0], average_keypoint[1]
162 |
163 |   def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
164 |       """
174 |       Returns:
175 |       - tuple: Filtered boxes, scores, labels, and keypoints.
176 |       """
177 |       areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
178 |       order = scores.argsort()[::-1]
179 | +     keep = []
180 |       while order.size > 0:
181 |           i = order[0]
182 | +         keep.append(i)
183 |           xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
184 |           yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
185 |           xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
186 |           yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
187 |           w = np.maximum(0.0, xx2 - xx1)
188 |           h = np.maximum(0.0, yy2 - yy1)
189 |           inter = w * h
190 |           iou = inter / (areas[i] + areas[order[1:]] - inter)
191 |           inds = np.where(iou <= iou_threshold)[0]
192 | +         order = order[inds + 1]
193 |       boxes = boxes[keep]
194 |       scores = scores[keep]
195 |       labels = labels[keep]
196 |       keypoints = keypoints[keep]
197 |       return boxes, scores, labels, keypoints
198 |
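A usage sketch for filter_overlap_boxes, which is a standard non-maximum-suppression pass (the arrays below are invented; keypoints are passed as one (2, 3) array per box):

import numpy as np

boxes = np.array([[0, 0, 100, 100], [10, 10, 110, 110], [200, 200, 300, 300]], dtype=float)
scores = np.array([0.9, 0.8, 0.95])
labels = np.array([1, 1, 2])
keypoints = np.zeros((3, 2, 3))
boxes, scores, labels, keypoints = filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5)
# the second box overlaps the first with IoU of about 0.68 and has the lower score, so only two boxes remain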
199 |   def draw_annotations(image,
200 |                        target=None,
201 |                        prediction=None,
216 |                        only_print=None,
217 |                        axis=False,
218 |                        return_image=False,
219 | +                      new_size=(1333, 800),
220 |                        resize=False):
221 |       """
222 |       Draws annotations on images including bounding boxes, keypoints, links, and text.
232 |       - draw_boxes (bool): Flag to draw bounding boxes.
233 |       - draw_text (bool): Flag to draw text annotations.
234 |       - draw_links (bool): Flag to draw links between annotations.
235 | +     - draw_twins (bool): Flag to draw twin keypoints.
236 |       - write_class (bool): Flag to write class names near the annotations.
237 |       - write_score (bool): Flag to write scores near the annotations.
238 |       - write_text (bool): Flag to write OCR recognized text.
249 |       image_copy = image.copy()
250 |       scale = max(image.shape[0], image.shape[1]) / 1000
251 |
252 | +     # Helper function to draw annotations based on provided data
253 | +     def draw(data, is_prediction=False):
254 |           for i in range(len(data['boxes'])):
255 | +             box = data['boxes'][i].tolist()
256 | +             x1, y1, x2, y2 = box
257 | +             if resize:
258 | +                 x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
259 |               if is_prediction:
260 |                   score = data['scores'][i].item()
261 |                   if score < score_threshold:
262 |                       continue
263 |               if draw_boxes:
264 |                   if only_print is not None:
265 |                       if data['labels'][i] != list(model_dict.values()).index(only_print):
266 |                           continue
267 | +                 cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0) if is_prediction else (0, 0, 0), int(2 * scale))
268 |               if is_prediction and write_score:
269 | +                 cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
270 |
271 |               if write_class and 'labels' in data:
272 |                   class_id = data['labels'][i].item()
273 | +                 cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2 * scale)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)
274 |
275 |               if write_idx:
276 | +                 cv2.putText(image_copy, str(i), (int(x1) + int(15 * scale), int(y1) + int(15 * scale)), cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
277 |
278 |           # Draw keypoints if available
279 |           if draw_keypoints and 'keypoints' in data:
280 |               if is_prediction and keypoints_correction:
281 |                   for idx, (key1, key2) in enumerate(data['keypoints']):
282 |                       if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
283 | +                                                    list(model_dict.values()).index('messageFlow'),
284 | +                                                    list(model_dict.values()).index('dataAssociation')]:
285 |                           continue
286 |                       distance = np.linalg.norm(key1[:2] - key2[:2])
287 |                       if distance < 5:
288 | +                         x_new, y_new, x, y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
289 | +                         data['keypoints'][idx][0] = torch.tensor([x_new, y_new, 1])
290 | +                         data['keypoints'][idx][1] = torch.tensor([x, y, 1])
291 |                           print("keypoint has been changed")
292 |               for i in range(len(data['keypoints'])):
293 |                   kp = data['keypoints'][i]
294 |                   for j in range(kp.shape[0]):
295 | +                     if is_prediction and data['labels'][i] not in [list(model_dict.values()).index('sequenceFlow'),
296 | +                                                                    list(model_dict.values()).index('messageFlow'),
297 | +                                                                    list(model_dict.values()).index('dataAssociation')]:
298 |                           continue
299 |                       if is_prediction:
300 |                           score = data['scores'][i]
301 |                           if score < score_threshold:
302 |                               continue
303 | +                     x, y, v = np.array(kp[j])
304 |                       if resize:
305 | +                         x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
306 |                       if j == 0:
307 | +                         cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (0, 0, 255), -1)
308 |                       else:
309 | +                         cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), (255, 0, 0), -1)
310 |
311 |       # Draw text predictions if available
312 | +     if (draw_text or write_text) and text_predictions is not None:
313 |           for i in range(len(text_predictions[0])):
314 |               x1, y1, x2, y2 = text_predictions[0][i]
315 |               text = text_predictions[1][i]
316 |               if resize:
317 | +                 x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
318 |               if draw_text:
319 | +                 cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2 * scale))
320 |               if write_text:
321 | +                 cv2.putText(image_copy, text, (int(x1 + int(2 * scale)), int((y1 + y2) / 2)), cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)
322 | +
323 |       def draw_with_links(full_prediction):
324 | +         """Draws links between objects based on the full prediction data."""
325 |           if draw_twins and full_prediction is not None:
326 | +             circle_color = (0, 255, 0)
327 | +             circle_radius = int(10 * scale)
328 |               for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
329 |                   if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
330 | +                                                           list(model_dict.values()).index('messageFlow'),
331 | +                                                           list(model_dict.values()).index('dataAssociation')]:
332 |                       continue
333 |                   distance = np.linalg.norm(key1[:2] - key2[:2])
334 |                   if distance < 10:
335 | +                     x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
336 |                       cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
337 | +                     cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)
338 |
339 | +         if draw_links and full_prediction is not None:
340 |               for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
341 |                   if start_idx is None or end_idx is None:
342 |                       continue
343 |                   start_box = full_prediction['boxes'][start_idx]
344 |                   end_box = full_prediction['boxes'][end_idx]
345 |                   current_box = full_prediction['boxes'][i]
346 |                   start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
347 |                   end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
348 |                   current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
349 | +                 cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2 * scale))
350 | +                 cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2 * scale))
351 |
352 | +                 i += 1
353 |
354 |       if target is not None:
355 |           draw(target, is_prediction=False)
356 |       if prediction is not None:
357 |           draw(prediction, is_prediction=True)
358 |       if full_prediction is not None:
359 |           draw_with_links(full_prediction)
360 |
361 |       image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
362 |       plt.figure(figsize=(12, 12))
363 |       plt.imshow(image_copy)
364 | +     if not axis:
365 |           plt.axis('off')
366 |       plt.show()
367 |
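To make the expected inputs concrete, a hedged sketch of the dictionary shape that draw_annotations reads for its prediction and target arguments (field names are taken from the code above; the tensor values and shapes are assumptions for illustration):

import torch

prediction = {
    'boxes': torch.tensor([[10.0, 10.0, 120.0, 80.0]]),    # one [x1, y1, x2, y2] box per object
    'labels': torch.tensor([1]),                            # indices into the class dictionaries
    'scores': torch.tensor([0.92]),                         # detection confidence per box
    'keypoints': torch.tensor([[[10.0, 45.0, 1.0],          # two (x, y, visibility) points per object
                                [120.0, 45.0, 1.0]]]),
}
# a full_prediction additionally carries a 'links' list of (start_idx, end_idx) pairs per arrow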
382 |       closest_object_idx = None
383 |       best_point = None
384 |       min_distance = float('inf')
385 |       for i, box in enumerate(boxes):
386 |           if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
387 |                            list(class_dict.values()).index('messageFlow'),
388 |                            list(class_dict.values()).index('dataAssociation'),
389 |                            list(class_dict.values()).index('lane')]:
390 |               continue
391 |           x1, y1, x2, y2 = box
392 |
393 | +         top = ((x1 + x2) / 2, y1)
394 | +         bottom = ((x1 + x2) / 2, y2)
395 | +         left = (x1, (y1 + y2) / 2)
396 | +         right = (x2, (y1 + y2) / 2)
397 | +         points = [left, top, right, bottom]
398 |
399 | +         pos_dict = {0: 'left', 1: 'top', 2: 'right', 3: 'bottom'}
400 |
401 | +         for pos, point in enumerate(points):
402 |               distance = np.linalg.norm(keypoint[:2] - point)
403 |               if distance < min_distance:
404 |                   min_distance = distance
405 |                   closest_object_idx = i
407 |
408 |       return closest_object_idx, best_point
409 |
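A usage sketch for find_closest_object (made-up inputs; labels index into class_dict and the keypoint is a NumPy (x, y) point):

import numpy as np

boxes = np.array([[0, 0, 100, 100], [300, 0, 400, 100]], dtype=float)
labels = np.array([1, 1])   # e.g. two 'task' boxes
keypoint = np.array([105.0, 50.0])
idx, point = find_closest_object(keypoint, boxes, labels)
# idx == 0: the keypoint lies closest to the right side of the first box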
410 |   def error(text='There is an error in the detection'):
411 | +     """Display an error message using Streamlit."""
412 |       st.error(text, icon="🚨")
413 |
414 |   def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'):
415 | +     """Display a warning message using Streamlit."""
416 |       st.warning(text, icon="⚠️")
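Both helpers are thin wrappers around Streamlit's message widgets; a minimal sketch of how they are meant to be called from inside a running Streamlit page (the message text here is invented):

error("No BPMN elements were detected in the uploaded image.")
warning()   # falls back to the default advisory text above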