from typing import Dict, List, Any

import base64
import io

import numpy as np
import torch
from PIL import Image
from transformers import pipeline, CLIPSegProcessor, CLIPSegForImageSegmentation


class EndpointHandler():
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(self.device)
        self.depth_pipe = pipeline(
            "depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small-hf",
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            image (:obj:`str`): a base64-encoded input image
            text (:obj:`str`): comma-separated segmentation prompts
        Return:
            A :obj:`list` of :obj:`dict`: will be serialized and returned
        """
        if "image" not in data or "text" not in data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        try:
            # Decode the base64 image and split the comma-separated prompts
            image = self.decode_image(data["image"])
            prompts = [p.strip() for p in data["text"].split(",")]

            # Preprocess: pair the same image with each text prompt
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)

            # logits shape is (num_prompts, H, W); a single prompt comes back
            # squeezed to (H, W), so restore the leading axis in that case
            logits = outputs.logits.cpu().numpy()
            if logits.ndim == 2:
                logits = logits[None, ...]

            results = []
            for prompt, mask in zip(prompts, logits):
                # Normalize each mask to 0-255 and encode it as a base64 PNG
                # so the response stays JSON-serializable
                mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
                mask = (mask * 255).astype(np.uint8)
                seg_image = Image.fromarray(mask)
                results.append({"prompt": prompt, "seg_image": self.encode_image(seg_image)})
            return results
        except Exception as e:
            return [{"error": str(e)}]

    # helper functions
    def decode_image(self, image_data: str) -> Image.Image:
        """Decodes a base64-encoded image into a PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: Image.Image) -> str:
        """Encodes a PIL image as a base64 PNG string for JSON serialization."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def process_depth(self, image) -> Image.Image:
        """Runs monocular depth estimation and returns the depth map as a grayscale image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        # Normalize to 0-255
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
        depth_map = (depth_map * 255).astype(np.uint8)
        return Image.fromarray(depth_map)
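

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch under assumptions, not part of the
# deployed handler): exercises __call__ the way an endpoint client would,
# with a base64-encoded image payload. The path "test.jpg" and the prompts
# are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    results = handler({"image": image_b64, "text": "a dog, the sky"})
    for result in results:
        if "error" in result:
            print("Error:", result["error"])
        else:
            print(result["prompt"], "-> base64 PNG mask,", len(result["seg_image"]), "chars")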