zero-shot-seg / app.py
danieaneta's picture
Update app.py
2010bf2 verified
import gradio as gr
from PIL import Image
import base64
import io
import numpy as np
from typing import List
from main import segmenter # Import the segmenter instance
def _error_response(message: str) -> dict:
    """Build the failure payload shared by every error path of process_image."""
    return {
        "success": False,
        "message": message,
        "image": None,
        "results": [],
    }


def process_image(image: Image.Image, objects_text: str) -> dict:
    """Segment the requested objects in an image and build a JSON-ready reply.

    Args:
        image: PIL image coming from the Gradio ``Image`` input. ``None`` is
            tolerated (Gradio passes ``None`` when nothing was uploaded) and
            reported as a failure instead of crashing.
        objects_text: Object labels separated by dots, e.g. ``"cat. dog. chair"``.
            Surrounding whitespace is stripped and empty entries are dropped.

    Returns:
        A dict with four keys:
            ``success``: ``True`` when segmentation ran, ``False`` otherwise.
            ``message``: human-readable status or error text.
            ``image``: base64-encoded PNG of the input image, ``None`` on failure.
            ``results``: one dict per detection with ``label``, ``confidence``
            and ``bounding_box``; empty on failure.

    Errors raised while segmenting or encoding are not propagated; they are
    converted into a ``success: False`` payload so the API always answers
    with the same shape.
    """
    # Fail fast with a clear message rather than letting a None image surface
    # as an opaque AttributeError from deeper in the pipeline.
    if image is None:
        return _error_response("No image provided.")

    # `or ""` tolerates a missing textbox value (None) from API callers.
    objects = [obj.strip() for obj in (objects_text or "").split(".") if obj.strip()]
    if not objects:
        return _error_response(
            "No objects specified. Enter labels separated by dots, e.g. 'cat. dog. chair'."
        )

    try:
        # Use the segmenter to process the image
        results = segmenter.segment_objects(image, objects)

        # No visualization is drawn yet: the original image is echoed back,
        # encoded as PNG and then base64 so it fits in a JSON payload.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("ascii")

        return {
            "success": True,
            "message": f"Processed image with objects: {objects}",
            "image": img_str,
            "results": [
                {
                    "label": r.label,
                    # float() turns numeric scalar types into a plain Python float.
                    "confidence": float(r.confidence),
                    # NOTE(review): passed through unchanged; assumes the
                    # segmenter already returns a JSON-serializable box — confirm.
                    "bounding_box": r.bounding_box,
                }
                for r in results
            ],
        }
    except Exception as e:
        # API boundary: report the failure in the response body instead of
        # raising, so clients always receive the same response structure.
        return _error_response(str(e))
# Gradio front end: one image input plus one free-text field for the labels,
# wired to process_image, with the returned dict rendered as JSON.
demo = gr.Interface(
    title="Zero Shot Segmentation",
    description="Upload an image and specify objects to detect.",
    fn=process_image,
    inputs=[
        # type="pil" makes Gradio hand the upload to the handler as a PIL image.
        gr.Image(label="Input Image", type="pil"),
        gr.Textbox(
            placeholder="cat. dog. chair",
            label="Objects (separate with dots)",
        ),
    ],
    outputs=gr.JSON(label="API Response"),
    allow_flagging="never",
)

# Turn on Gradio's request queue for this interface.
demo.queue()
if __name__ == "__main__":
    # Script entry point: start the web server only when this file is run
    # directly, not when it is imported.
    demo.launch(
        server_name="0.0.0.0",  # listen on every network interface
        server_port=7860,
        share=True,  # ask Gradio for a public share link
        show_api=True,  # expose the API documentation page
    )