Spaces:
Running
Running
import json
import os.path
import uuid

import torch
from PIL import Image
from sacred import Experiment
from tqdm import tqdm

from engie_pipeline.models import (
    model_ingredient,
    get_detr,
    get_detr_feature_extractor,
    get_siglip,
    get_siglip_preprocessor,
)
from engie_pipeline.utils import draw_boxes
# Sacred experiment for this module. model_ingredient (imported from
# engie_pipeline.models) is attached so its configuration becomes part of
# this experiment's configuration.
pipeline_experiment = Experiment(name="pipeline", ingredients=[model_ingredient])
# Sacred config scope: every local assigned below becomes a config entry.
# Sacred injects it into captured functions (e.g. ``run``) whose parameter
# has the same name. Keep the ``def`` line exactly as is (no annotations or
# trailing comments): sacred locates the body by matching that signature.
@pipeline_experiment.config
def config():
    path = "data/"  # image file, or directory of images, to process
    output_path = "data/output"  # root folder of the per-verdict outputs
    conformity_threshold = 0.8  # a label-2 detection must score strictly above this for "Conforme"
def set_up_pipeline(output_path):
    """Create the per-verdict output folders and load the models.

    One sub-directory of *output_path* is created for each verdict
    ("Conforme", "Non-conforme", "Hors-sujet", "Non-admissible"). Both
    models are put in evaluation mode with ``.eval()``.

    Args:
        output_path: Root directory under which the verdict folders live.

    Returns:
        A dict with keys ``detr``, ``detr_preprocessor``, ``siglip`` and
        ``siglip_preprocessor``, meant to be splatted into
        ``pipeline(**models, ...)``.
    """
    for verdict in ("Conforme", "Non-conforme", "Hors-sujet", "Non-admissible"):
        # exist_ok already tolerates existing folders, so no isdir() pre-check
        # is needed (and none can race with another process creating it).
        os.makedirs(os.path.join(output_path, verdict), exist_ok=True)

    detr = get_detr()
    detr.eval()
    detr_preprocessor = get_detr_feature_extractor()
    siglip = get_siglip()
    siglip.eval()
    siglip_preprocessor = get_siglip_preprocessor()

    return {
        "detr": detr,
        "detr_preprocessor": detr_preprocessor,
        "siglip": siglip,
        "siglip_preprocessor": siglip_preprocessor,
    }
def pipeline(
    detr,
    detr_preprocessor,
    siglip,
    siglip_preprocessor,
    image: Image.Image,
    output_path: str,
    conformity_threshold: float,
    force_detr: bool = False,
):
    """Classify one image, run detection when relevant, and save the results.

    SigLIP first classifies the whole image. Predicted class 1 means
    "Hors-sujet" and class 2 means "Non-admissible". For those, the image
    and a JSON file with the class probabilities are written to the matching
    folder and the function stops, unless *force_detr* is set.

    Otherwise DETR detects objects (label 0 "tableau", 1 "disjoncteur",
    2 "bouton de test"). The verdict is "Conforme" when some label-2
    detection scores strictly above *conformity_threshold*, and
    "Non-conforme" otherwise.

    The image, annotated with the boxes, is written to
    ``<output_path>/<verdict>/<name>.jpg``. A ``.json`` file next to it
    holds the classification probabilities and per-label scores and boxes.

    Images whose ``<name>.jpg`` already exists in any verdict folder are
    skipped, so an interrupted batch can be resumed.

    Args:
        detr: Detection model exposing ``process_output(output, size)`` that
            returns ``(boxes, labels, scores)``.
        detr_preprocessor: Callable mapping a PIL image to model inputs
            containing ``"pixel_values"``.
        siglip: Image classifier whose output supports ``.softmax(-1)``.
        siglip_preprocessor: Callable mapping a PIL image to a single,
            unbatched tensor (presumably (C, H, W) -- confirm).
        image: Input image. The stem of its ``filename`` names the outputs;
            when no usable filename exists, a random UUID is used instead.
        output_path: Root folder containing one sub-folder per verdict.
        conformity_threshold: A label-2 detection must score strictly above
            this value for the image to be "Conforme".
        force_detr: Also run detection for "Hors-sujet" / "Non-admissible"
            images. The verdict then stays the SigLIP one.

    Returns:
        ``(siglip_probs, boxes, labels, scores, conformity)`` when detection
        ran, or ``None`` when the image was skipped or stopped after
        classification.
    """
    verdicts = ("Hors-sujet", "Non-admissible", "Conforme", "Non-conforme")

    # splitext keeps dot-less names intact. The UUID fallback covers images
    # without a usable filename (none at all, or an empty one, e.g. when
    # opened from a file object).
    source_name = os.path.basename(getattr(image, "filename", "") or "")
    filename = os.path.splitext(source_name)[0] or str(uuid.uuid4())

    def output_file(verdict: str, extension: str) -> str:
        return os.path.join(output_path, verdict, filename + extension)

    # Resume support: skip images already written under any verdict.
    if any(os.path.isfile(output_file(verdict, ".jpg")) for verdict in verdicts):
        return None

    # JPEG cannot store alpha/palette modes, so saving e.g. an RGBA PNG as
    # .jpg would fail. Convert once, after the filename has been read.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pure inference: no_grad avoids building and holding autograd graphs.
    with torch.no_grad():
        siglip_input = siglip_preprocessor(image.copy())
        siglip_probs = siglip(siglip_input.unsqueeze(0)).softmax(-1)

    # Classes 1 and 2 short-circuit detection. Any other class is decided by
    # DETR below.
    conformity = {1: "Hors-sujet", 2: "Non-admissible"}.get(int(siglip_probs.argmax()))
    if conformity is not None and not force_detr:
        image.save(output_file(conformity, ".jpg"))
        with open(output_file(conformity, ".json"), "w") as file:
            json.dump({"classification_probs": siglip_probs.tolist()}, file, indent=4)
        return None
    # With force_detr, the annotated image saved below is the only .jpg
    # written, at the same path the clean image would have used.

    with torch.no_grad():
        detr_input = detr_preprocessor(image.copy(), return_tensors="pt")
        detr_output = detr(detr_input["pixel_values"])
        boxes, labels, scores = detr.process_output(detr_output, image.size)

    if conformity is None:
        has_test_button = 2 in labels and scores[labels == 2].max() > conformity_threshold
        conformity = "Conforme" if has_test_button else "Non-conforme"

    # Overlay the boxes of labels 0, 1 and 2, in that order.
    for label_id, color in enumerate(("gray", "orange", "purple")):
        selected = labels == label_id
        image = draw_boxes(
            image=image, boxes=boxes[selected], probs=scores[selected], color=color
        )

    image.save(output_file(conformity, ".jpg"))
    with open(output_file(conformity, ".json"), "w") as file:
        json.dump(
            {
                "classification_probs": siglip_probs.tolist(),
                "detection_statistics": {
                    label: {
                        "scores": scores[labels == i].tolist(),
                        "boxes": boxes[labels == i].tolist(),
                    }
                    for i, label in enumerate(
                        ["tableau", "disjoncteur", "bouton de test"]
                    )
                },
            },
            file,
            indent=4,
        )
    return siglip_probs, boxes, labels, scores, conformity
def _open_rgb(image_path):
    """Load *image_path* as an RGB image that keeps its ``filename``.

    ``Image.convert`` returns a new image without the source ``filename``
    attribute. ``pipeline`` relies on that attribute to name its outputs and
    to skip already-processed images, so it is copied across. Opening inside
    ``with`` closes the file handle once the pixels have been converted.
    """
    with Image.open(image_path) as source:
        image = source.convert("RGB")
        image.filename = source.filename
    return image


# automain starts the experiment as soon as this decorator runs when the
# module is executed as a script, so keep this the last definition in the file.
@pipeline_experiment.automain
def run(path: str, output_path: str = "data/output", conformity_threshold: float = 0.8):
    """Run ``pipeline`` on one image file or on every image in a directory.

    This is the experiment's main command. When launched through sacred,
    ``path``, ``output_path`` and ``conformity_threshold`` come from the
    experiment config. The defaults here mirror that config so ``run`` can
    also be called directly.

    Args:
        path: An image file, or a directory whose ``.jpg`` / ``.jpeg`` /
            ``.png`` files (case-insensitive, non-recursive) are processed.
        output_path: Root folder for the per-verdict outputs.
        conformity_threshold: Forwarded to ``pipeline``.
    """
    models = set_up_pipeline(output_path)
    options = {"output_path": output_path, "conformity_threshold": conformity_threshold}

    if os.path.isfile(path):
        pipeline(**models, image=_open_rgb(path), **options)
        return  # a single file: must not fall through to os.listdir(path)

    for entry in tqdm(os.listdir(path)):
        if not entry.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        pipeline(**models, image=_open_rgb(os.path.join(path, entry)), **options)