Tools / gradio_tools /tools /sam_with_clip.py
ZackBradshaw's picture
Upload folder using huggingface_hub
e67043b verified
raw
history blame
2.43 kB
from __future__ import annotations
from typing import TYPE_CHECKING
from gradio_client.client import Job
from gradio_tools.tools.gradio_tool import GradioTool
if TYPE_CHECKING:
import gradio as gr
class SAMImageSegmentationTool(GradioTool):
    """Tool for segmenting images based on natural language queries.

    The tool receives a single ``|``-separated string holding five fields
    (image path/URL, text query, predicted IoU threshold, stability score
    threshold, CLIP threshold) and forwards them to the ``/predict``
    endpoint of the configured Gradio app.
    """

    def __init__(
        self,
        name="SAMImageSegmentation",
        description=(
            "A tool for identifying objects in images. "
            "Input will be a five strings separated by a |: "
            "the first will be the full path or URL to an image file. "
            "The second will be the string query describing the objects to identify in the image. "
            "The query string should be as detailed as possible. "
            "The third will be the predicted_iou_threshold, if not specified by the user set it to 0.9. "
            "The fourth will be the stability_score_threshold, if not specified by the user set it to 0.8. "
            "The fifth is the clip_threshold, if not specified by the user set it to 0.85. "
            "The output will the a path with an image file with the identified objects overlayed in the image."
        ),
        src="curt-park/segment-anything-with-clip",
        hf_token=None,
        duplicate=False,
    ) -> None:
        """Forward the tool configuration unchanged to ``GradioTool``."""
        super().__init__(name, description, src, hf_token, duplicate)

    def create_job(self, query: str) -> Job:
        """Parse the ``|``-separated input and submit a prediction job.

        Args:
            query: Five fields separated by ``|``, in this order: image
                path or URL, text query, predicted_iou_threshold,
                stability_score_threshold, clip_threshold.

        Returns:
            Whatever ``self.client.submit`` returns for the request.

        Raises:
            ValueError: If ``query`` does not contain exactly five fields,
                or if any of the three thresholds is not a valid float.
        """
        # Strip every field so that inputs written as "a | b | c" do not
        # leak padding whitespace into the image path or the text query.
        fields = [part.strip() for part in query.split("|")]
        if len(fields) != 5:
            raise ValueError(
                "Wrong number of arguments passed to the SAMImageSegmentationTool! "
                "Expected 5 (image, query, predicted_iou_threshold, stability_score_threshold, clip_threshold), "
                f"got {len(fields)}"
            )
        (
            image,
            text_query,
            predicted_iou_threshold,
            stability_score_threshold,
            clip_threshold,
        ) = fields
        # The endpoint is called with the three thresholds first, followed
        # by the image and the text query.
        return self.client.submit(
            float(predicted_iou_threshold),
            float(stability_score_threshold),
            float(clip_threshold),
            image,
            text_query,
            api_name="/predict",
        )

    def postprocess(self, output: str) -> str:
        """Return the job output unchanged."""
        return output

    def _block_input(self, gr) -> "gr.components.Component":
        """Return the component for the tool input: a textbox."""
        return gr.Textbox()

    def _block_output(self, gr) -> "gr.components.Component":
        """Return the component for the tool output: an image.

        The tool description states that the output is a path to an image
        file, so an image component (not an audio one) is the matching type.
        """
        return gr.Image()