Spaces:
Running
Running
import albumentations as A | |
import base64 | |
import cv2 | |
import gradio as gr | |
import inspect | |
import io | |
import numpy as np | |
import os | |
from dataclasses import dataclass | |
from loguru import logger | |
from copy import deepcopy | |
from functools import wraps | |
from PIL import Image, ImageDraw | |
from typing import get_type_hints, Optional | |
from pydantic_core._pydantic_core import ValidationError | |
# from mixpanel import Mixpanel | |
from utils import is_not_supported_transform | |
# Some constants for Albumentations | |
PositionType = A.PadIfNeeded.PositionType | |
# MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN") | |
# mp = Mixpanel(MIXPANEL_TOKEN) | |
HEADER = f""" | |
<div align="center"> | |
<p> | |
<img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;"> | |
<span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span> | |
</p> | |
<p style="margin-top: -15px;"> | |
<a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a> | |
| |
<a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a> | |
</p> | |
</div> | |
""" | |
DEFAULT_TRANSFORM = "Rotate" | |
NO_OPERATION_TRANFORM = "NoOp" | |
DEFAULT_IMAGE_PATH = "images/doctor.webp" | |
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH)) | |
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0] | |
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1] | |
DEFAULT_BOXES = [ | |
[265, 121, 326, 177], # Mask | |
[192, 169, 401, 395], # Coverall | |
] | |
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]] | |
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]] | |
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]] | |
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints | |
BASE64_DEFAULT_MASKS = [ | |
{ | |
"label": "Coverall", | |
# light green color | |
"color": (144, 238, 144), | |
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==", | |
}, | |
{ | |
"label": "Mask", | |
# light blue color | |
"color": (173, 216, 230), | |
"mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC", | |
}, | |
] | |
# Get all the transforms from the albumentations library | |
transforms_map = { | |
name: cls | |
for name, cls in vars(A).items() | |
if ( | |
inspect.isclass(cls) | |
and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform)) | |
and not is_not_supported_transform(cls) | |
) | |
} | |
transforms_map.pop("DualTransform", None) | |
transforms_map.pop("ImageOnlyTransform", None) | |
transforms_map.pop("ReferenceBasedTransform", None) | |
transforms_map.pop("ToFloat", None) | |
transforms_map.pop("Normalize", None) | |
transforms_keys = list(sorted(transforms_map.keys())) | |
# Decode the masks | |
for mask in BASE64_DEFAULT_MASKS: | |
mask["mask"] = np.array( | |
Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L") | |
) | |
class RequestParams: | |
user_ip: str | |
transform_name: Optional[str] | |
def track_event(event_name, user_id="unknown", properties=None): | |
if properties is None: | |
properties = {} | |
#mp.track(user_id, event_name, properties) | |
logger.info(f"Event tracked: {event_name} - {properties}") | |
def get_params(request: gr.Request) -> RequestParams: | |
"""Parse input request parameters.""" | |
ip = request.client.host | |
transform_name = request.query_params.get("transform", None) | |
params = RequestParams(user_ip=ip, transform_name=transform_name) | |
track_event("app_opened", user_id=params.user_ip, properties={"transform_name": params.transform_name}) | |
return params | |
def run_with_retry(compose): | |
def wrapper(*args, **kwargs): | |
processors = deepcopy(compose.processors) | |
for _ in range(4): | |
try: | |
result = compose(*args, **kwargs) | |
break | |
except NotImplementedError as e: | |
print(f"Caught NotImplementedError: {e}") | |
if "bbox" in str(e): | |
kwargs.pop("bboxes", None) | |
kwargs.pop("category_id", None) | |
compose.processors.pop("bboxes") | |
if "keypoint" in str(e): | |
kwargs.pop("keypoints", None) | |
compose.processors.pop("keypoints") | |
if "mask" in str(e): | |
kwargs.pop("mask", None) | |
except (ValueError, ValidationError) as e: | |
raise gr.Error(str(e)) | |
except Exception as e: | |
compose.processors = processors | |
raise e | |
compose.processors = processors | |
return result | |
return wrapper | |
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray: | |
"""Draw boxes with PIL.""" | |
pil_image = Image.fromarray(image) | |
draw = ImageDraw.Draw(pil_image) | |
for box in boxes: | |
x_min, y_min, x_max, y_max = box | |
draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness) | |
return np.array(pil_image) | |
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2): | |
"""Draw keypoints with PIL.""" | |
pil_image = Image.fromarray(image) | |
draw = ImageDraw.Draw(pil_image) | |
for keypoint in keypoints: | |
x, y = keypoint | |
draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color) | |
return np.array(pil_image) | |
def get_rgb_mask(masks): | |
"""Get the RGB mask from the binary mask.""" | |
rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8) | |
for data in masks: | |
mask = data["mask"] | |
rgb_mask[mask > 0] = np.array(data["color"]) | |
return rgb_mask | |
def draw_mask(image, mask): | |
"""Draw the mask on the image.""" | |
image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0) | |
return image_with_mask | |
def draw_not_implemented_image(image: np.ndarray, annotation_type: str): | |
"""Draw the image with a text. In the middle.""" | |
pil_image = Image.fromarray(image) | |
draw = ImageDraw.Draw(pil_image) | |
# align in the centerm, and make bigger font | |
text = f'Transform NOT working with "{annotation_type.upper()}" annotations.' | |
length = draw.textlength(text) | |
draw.text( | |
(DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2), | |
text, | |
fill=(255, 0, 0), | |
align="center", | |
) | |
return np.array(pil_image) | |
def get_formatted_signature(function_or_class, indentation=4): | |
signature = inspect.signature(function_or_class) | |
type_hints = get_type_hints(function_or_class) | |
args = [] | |
for param in signature.parameters.values(): | |
if param.name == "p": | |
str_param = "p=1.0," | |
elif param.default == inspect.Parameter.empty: | |
if "height" in param.name or "width" in param.name: | |
str_param = f"{param.name}=300," | |
else: | |
str_param = f"{param.name}=," | |
else: | |
if isinstance(param.default, str): | |
str_param = f'{param.name}="{param.default}",' | |
else: | |
str_param = f"{param.name}={param.default}," | |
annotation = type_hints.get(param.name, param.annotation) | |
if isinstance(param.annotation, type): | |
str_param += f" # {param.annotation.__name__}" | |
else: | |
str_annotation = str(annotation).replace("typing.", "") | |
str_param += f" # {str_annotation}" | |
str_param = "\n" + " " * indentation + str_param | |
args.append(str_param) | |
result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")" | |
return result | |
def get_formatted_transform(transform_name): | |
track_event("transform_selected", properties={"transform_name": transform_name}) | |
transform = transforms_map[transform_name] | |
return f"A.{transform.__name__}{get_formatted_signature(transform)}" | |
def get_formatted_transform_docs(transform_name): | |
transform = transforms_map[transform_name] | |
return transform.__doc__.strip("\n") | |
def update_augmented_images(image, code): | |
if "=," in code: | |
raise gr.Error("You have to fill in parameters to apply transform! See 'Code' section!") | |
try: | |
augmentation = eval(code) | |
except ValidationError as e: | |
raise gr.Error(str(e)) | |
except Exception as e: | |
logger.info(code) | |
logger.error(e) | |
raise e | |
track_event("transform_applied", properties={"transform_name": augmentation.__class__.__name__, "code": code}) | |
compose = A.Compose( | |
[augmentation], | |
bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]), | |
keypoint_params=A.KeypointParams(format="xy"), | |
) | |
compose = run_with_retry(compose) # to prevent NotImplementedError | |
keypoints = DEFAULT_KEYPOINTS | |
bboxes = DEFAULT_BOXES | |
mask = get_rgb_mask(BASE64_DEFAULT_MASKS) | |
augmented = compose( | |
image=image, | |
mask=mask, | |
keypoints=keypoints, | |
bboxes=bboxes, | |
category_id=range(len(bboxes)), | |
) | |
image = augmented["image"] | |
mask = augmented.get("mask", None) | |
bboxes = augmented.get("bboxes", None) | |
keypoints = augmented.get("keypoints", None) | |
# Draw the augmented images (or replace by placeholder if not implemented) | |
if mask is not None: | |
image_with_mask = draw_mask(image.copy(), mask) | |
else: | |
image_with_mask = draw_not_implemented_image(image.copy(), "mask") | |
if bboxes is not None: | |
image_with_bboxes = draw_boxes(image.copy(), bboxes) | |
else: | |
image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes") | |
if keypoints is not None: | |
image_with_keypoints = draw_keypoints(image.copy(), keypoints) | |
else: | |
image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints") | |
return [ | |
(image_with_mask, "Mask"), | |
(image_with_bboxes, "Boxes"), | |
(image_with_keypoints, "Keypoints"), | |
] | |
def update_image_info(image): | |
h, w = image.shape[:2] | |
dtype = image.dtype | |
max_, min_ = image.max(), image.min() | |
return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}" | |
def update_code_and_docs(select): | |
code = get_formatted_transform(select) | |
docs = get_formatted_transform_docs(select) | |
return code, docs | |
def update_code_and_docs_on_start(url_params: gr.Request): | |
params = get_params(url_params) | |
if params.transform_name is not None and params.transform_name not in transforms_map: | |
gr.Warning(f"Sorry, `{params.transform_name}` transform is not supported at the moment :(") | |
transform_name = NO_OPERATION_TRANFORM | |
elif params.transform_name in transforms_map: | |
transform_name = params.transform_name | |
else: | |
transform_name = DEFAULT_TRANSFORM | |
return gr.update(value=transform_name) | |
with gr.Blocks() as demo: | |
gr.Markdown(HEADER) | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Group(): | |
# gr.Markdown( | |
# (" " * 4) + \ | |
# "If a component is loading on start, please, try to refresh the page a few times. [Working on fix...]" | |
# ) | |
select = gr.Dropdown( | |
label="Select a transformation", | |
choices=transforms_keys, | |
value=DEFAULT_TRANSFORM, | |
type="value", | |
interactive=True, | |
) | |
with gr.Accordion("Documentation (click to expand)", open=False): | |
docs = gr.TextArea( | |
get_formatted_transform_docs(DEFAULT_TRANSFORM), | |
show_label=False, | |
interactive=False, | |
) | |
code = gr.Code( | |
label="Code", | |
language="python", | |
value=get_formatted_transform(DEFAULT_TRANSFORM), | |
interactive=True, | |
lines=5, | |
) | |
info = gr.TextArea( | |
value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)", | |
show_label=False, | |
lines=1, | |
max_lines=1, | |
) | |
button = gr.Button("Apply!") | |
image = gr.Image( | |
value=DEFAULT_IMAGE_PATH, | |
type="numpy", | |
height=500, | |
width=300, | |
sources=[], | |
) | |
with gr.Row(): | |
augmented_image = gr.Gallery( | |
value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"), | |
rows=1, | |
columns=3, | |
show_label=False, | |
) | |
select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs]) | |
button.click( | |
fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image] | |
) | |
demo.load( | |
update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False | |
) | |
if __name__ == "__main__": | |
demo.launch() | |