Tools / gradio_tools /tools /gradio_tool.py
ZackBradshaw's picture
Upload folder using huggingface_hub
e67043b verified
raw
history blame
3.88 kB
from __future__ import annotations
import time
from abc import abstractmethod
from typing import Any, Tuple, Union
import gradio_client as grc
import huggingface_hub
from gradio_client.client import Job
from gradio_client.utils import QueueError
try:
import langchain as lc
LANGCHAIN_INSTALLED = True
except (ModuleNotFoundError, ImportError):
LANGCHAIN_INSTALLED = False
class GradioTool:
def __init__(
self,
name: str,
description: str,
src: str,
hf_token: str | None = None,
duplicate: bool = True,
) -> None:
self.name = name
self.description = description
if hf_token and self._is_space(src) and duplicate:
self.client = grc.Client.duplicate(from_id=src, hf_token=hf_token)
self.src = self.client.space_id
else:
self.src = src
self.client = grc.Client(self.src, hf_token=hf_token)
self.description += (
"If the output is the string QUEUE_FULL you "
"should try this action again and never stop until you "
"don't see this message."
)
self._block = None
@staticmethod
def _is_space(src: str) -> bool:
try:
huggingface_hub.get_space_runtime(src)
return True
except huggingface_hub.hf_api.RepositoryNotFoundError:
return False
@abstractmethod
def create_job(self, query: str) -> Job:
pass
@abstractmethod
def postprocess(self, output: Union[Tuple[Any], Any]) -> str:
pass
def run(self, query: str):
job = self.create_job(query)
while not job.done():
status = job.status()
print(f"\nJob Status: {str(status.code)} eta: {status.eta}")
time.sleep(30)
try:
output = self.postprocess(job.result())
except QueueError:
output = "QUEUE_FULL"
return output
# Optional gradio functionalities
def _block_input(self, gr) -> "gr.components.Component":
return gr.Textbox()
def _block_output(self, gr) -> "gr.components.Component":
return gr.Textbox()
def block_input(self) -> "gr.components.Component":
try:
import gradio as gr
GRADIO_INSTALLED = True
except (ModuleNotFoundError, ImportError):
GRADIO_INSTALLED = False
if not GRADIO_INSTALLED:
raise ModuleNotFoundError("gradio must be installed to call block_input")
else:
return self._block_input(gr)
def block_output(self) -> "gr.components.Component":
try:
import gradio as gr
GRADIO_INSTALLED = True
except (ModuleNotFoundError, ImportError):
GRADIO_INSTALLED = False
if not GRADIO_INSTALLED:
raise ModuleNotFoundError("gradio must be installed to call block_output")
else:
return self._block_output(gr)
def block(self):
"""Get the gradio Blocks of this tool for visualization."""
try:
import gradio as gr
except (ModuleNotFoundError, ImportError):
raise ModuleNotFoundError("gradio must be installed to call block")
if not self._block:
self._block = gr.load(name=self.src, src="spaces")
return self._block
# Optional langchain functionalities
@property
def langchain(self) -> "langchain.agents.Tool": # type: ignore
if not LANGCHAIN_INSTALLED:
raise ModuleNotFoundError(
"langchain must be installed to access langchain tool"
)
return lc.agents.Tool( # type: ignore
name=self.name, func=self.run, description=self.description
)
def __repr__(self) -> str:
return f"GradioTool(name={self.name}, src={self.src})"