import base64
import io
import logging
from typing import List

import gradio as gr
import numpy as np
from PIL import Image

from main import segmenter  # Import the segmenter instance

logger = logging.getLogger(__name__)


def _error_response(message: str) -> dict:
    """Build the failure payload returned by ``process_image``."""
    return {
        "success": False,
        "message": message,
        "image": None,
        "results": [],
    }


def _encode_png_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes base64-encoded as text."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _jsonable_box(box):
    """Convert an array-like bounding box to plain lists; pass others through.

    NOTE(review): the concrete type of ``bounding_box`` is defined by the
    segmenter, which is not visible here. Objects exposing ``tolist()``
    (e.g. numpy arrays) are converted; anything else is returned unchanged.
    """
    return box.tolist() if hasattr(box, "tolist") else box


def process_image(image: Image.Image, objects_text: str) -> dict:
    """Segment the requested objects in an image and build the API response.

    Args:
        image: Input image as a PIL image. May be ``None`` when the request
            carries no image; this yields a failure response.
        objects_text: Object names separated by dots, e.g. ``"cat. dog."``.
            Surrounding whitespace and empty entries are ignored.

    Returns:
        A dict with the keys ``success`` (bool), ``message`` (str),
        ``image`` (base64-encoded PNG of the input image, or ``None`` on
        failure) and ``results`` (a list of dicts with ``label``,
        ``confidence`` and ``bounding_box``). Errors are reported through
        ``success=False`` and ``message`` rather than raised.
    """
    # Fail early with a clear message instead of an incidental exception
    # raised from somewhere inside the segmenter or PIL.
    if image is None:
        return _error_response("No image provided")

    try:
        # Parse objects (tolerate a missing text value)
        objects = [
            name.strip()
            for name in (objects_text or "").split(".")
            if name.strip()
        ]

        # Use the segmenter to process the image
        results = segmenter.segment_objects(image, objects)

        # Create visualization of results
        # For now, just returning the original image
        img_str = _encode_png_base64(image)

        # Format results for response
        return {
            "success": True,
            "message": f"Processed image with objects: {objects}",
            "image": img_str,
            "results": [
                {
                    "label": r.label,
                    "confidence": float(r.confidence),
                    "bounding_box": _jsonable_box(r.bounding_box),
                }
                for r in results
            ],
        }
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # and report the failure to the caller in the normal response shape.
        logger.exception("Image processing failed")
        return _error_response(str(e))


# Create Gradio interface with API mode enabled
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(
            label="Objects (separate with dots)",
            placeholder="cat. dog. chair",
        ),
    ],
    outputs=gr.JSON(label="API Response"),
    title="Zero Shot Segmentation",
    description="Upload an image and specify objects to detect.",
    allow_flagging="never",
)

# Enable request queuing
demo.queue()

if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
    )