from typing import Dict, List, Any
from transformers import pipeline, CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image
import torch
import base64
import io
import numpy as np


class EndpointHandler:
    def __init__(self, path=""):
        # Preload everything needed at inference time: CLIPSeg for text-prompted
        # segmentation and Depth Anything V2 for depth estimation.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(self.device)
        self.depth_pipe = pipeline(
            "depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): {"image": <base64 str>, "text": <str | list of str>}
        Return:
            A :obj:`list` of dicts; will be serialized and returned.
        """
        if "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]

        inputs_data = data["inputs"]
        if "image" not in inputs_data or "text" not in inputs_data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        try:
            # Decode the base64-encoded image and normalize prompts to a list.
            image = self.decode_image(inputs_data["image"])
            prompts = inputs_data["text"]
            if isinstance(prompts, str):
                prompts = [prompts]

            # Preprocess input: one copy of the image per text prompt.
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)

            # Run inference.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # CLIPSeg squeezes the batch dimension when there is a single prompt;
            # restore it so logits always have shape (num_prompts, H, W).
            logits = outputs.logits.cpu().numpy()
            if logits.ndim == 2:
                logits = logits[np.newaxis, ...]

            results = []
            for prompt, mask in zip(prompts, logits):
                # Normalize each mask to 0-255 and encode it as a PNG.
                mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
                mask = (mask * 255).astype(np.uint8)
                seg_image = Image.fromarray(mask)
                results.append({"prompt": prompt, "seg_image": self.encode_image(seg_image)})
            return results
        except Exception as e:
            return [{"error": str(e)}]

    # helper functions
    def decode_image(self, image_data: str) -> Image.Image:
        """Decodes a base64-encoded image into a PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: Image.Image) -> str:
        """Encodes a PIL image to a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def process_depth(self, image):
        """Runs depth estimation and returns the depth map as a grayscale PIL image."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        # Normalize to 0-255.
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
        depth_map = (depth_map * 255).astype(np.uint8)
        return Image.fromarray(depth_map)
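

# --- Local smoke test: an illustrative sketch, not part of the deployed handler ---
# This block shows one assumed way to exercise the handler outside the endpoint:
# it fabricates a blank test image, base64-encodes it in the payload shape that
# __call__ checks for ("inputs" -> "image"/"text"), and invokes the handler
# directly. The image size, prompt text, and printed fields are all hypothetical
# choices for illustration; it also calls process_depth directly, since that
# helper is not wired into __call__.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Build a dummy RGB image and encode it the same way a client would.
    dummy = Image.new("RGB", (352, 352), color=(128, 128, 128))
    buf = io.BytesIO()
    dummy.save(buf, format="PNG")
    payload = {
        "inputs": {
            "image": base64.b64encode(buf.getvalue()).decode("utf-8"),
            "text": ["a gray square"],  # hypothetical prompt
        }
    }

    result = handler(payload)
    # Expect one dict per prompt with 'prompt' and 'seg_image' keys, or an 'error' key.
    print(result[0].keys())

    # Sanity-check the depth helper on the same dummy image.
    depth_image = handler.process_depth(dummy)
    print(depth_image.size, depth_image.mode)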