# Hugging Face Spaces status: Runtime error
from __future__ import annotations | |
import time | |
from abc import abstractmethod | |
from typing import Any, Tuple, Union | |
import gradio_client as grc | |
import huggingface_hub | |
from gradio_client.client import Job | |
from gradio_client.utils import QueueError | |
try: | |
import langchain as lc | |
LANGCHAIN_INSTALLED = True | |
except (ModuleNotFoundError, ImportError): | |
LANGCHAIN_INSTALLED = False | |
class GradioTool:
    """Base class that exposes a Gradio app as a tool for LLM agents.

    Subclasses implement :meth:`create_job` (submit a query to the app and
    return the ``gradio_client`` job) and :meth:`postprocess` (turn the job's
    raw result into a string). This base class builds the client, polls the
    job, reports a full queue as ``"QUEUE_FULL"``, and offers optional Gradio
    (``block*``) and LangChain (``langchain``) integrations.
    """

    # Appended to every tool description so the agent knows how to react when
    # ``run`` returns "QUEUE_FULL".
    _QUEUE_FULL_HINT = (
        "If the output is the string QUEUE_FULL you "
        "should try this action again and never stop until you "
        "don't see this message."
    )

    # Seconds ``run`` sleeps between status polls; subclasses or instances may
    # override it.
    poll_interval: float = 30

    def __init__(
        self,
        name: str,
        description: str,
        src: str,
        hf_token: str | None = None,
        duplicate: bool = True,
    ) -> None:
        """Create the tool and its ``gradio_client`` connection.

        Args:
            name: Tool name reported to agents.
            description: What the tool does; the QUEUE_FULL retry hint is
                appended to it.
            src: Identifier handed to ``gradio_client.Client`` (e.g. a Hugging
                Face Space id).
            hf_token: Hugging Face token forwarded to the client. When it is
                given, ``duplicate`` is true and ``src`` is an existing Space,
                the Space is duplicated via ``grc.Client.duplicate`` first.
            duplicate: Whether to duplicate the Space when ``hf_token`` is set.
        """
        self.name = name
        self.description = self._append_queue_hint(description)
        # Check ``duplicate`` before ``_is_space`` so no network lookup happens
        # when duplication is disabled anyway.
        if hf_token and duplicate and self._is_space(src, hf_token):
            self.client = grc.Client.duplicate(from_id=src, hf_token=hf_token)
            # The duplicate has its own id; talk to it from now on.
            self.src = self.client.space_id
        else:
            self.src = src
            self.client = grc.Client(self.src, hf_token=hf_token)
        # Cache for the Blocks object built lazily by ``block``.
        self._block = None

    @classmethod
    def _append_queue_hint(cls, description: str) -> str:
        """Return ``description`` followed by the QUEUE_FULL retry hint.

        A single space is inserted when ``description`` does not already end
        in whitespace, so the two sentences do not run together.
        """
        needs_space = bool(description) and not description[-1].isspace()
        separator = " " if needs_space else ""
        return description + separator + cls._QUEUE_FULL_HINT

    @staticmethod
    def _is_space(src: str, hf_token: str | None = None) -> bool:
        """Return True if ``src`` is a Hugging Face Space id that can be found.

        ``hf_token`` is forwarded to the lookup so private Spaces accessible
        with that token can be found. Missing repos and values that are not
        valid repo ids (e.g. full URLs) both yield False.
        """
        from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError

        try:
            huggingface_hub.get_space_runtime(src, token=hf_token)
        except (RepositoryNotFoundError, HFValidationError):
            return False
        return True

    @abstractmethod
    def create_job(self, query: str) -> Job:
        """Submit ``query`` to the app and return the resulting job.

        Must be overridden by subclasses.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement create_job")

    @abstractmethod
    def postprocess(self, output: Union[Tuple[Any], Any]) -> str:
        """Convert the job's raw ``output`` into the string given to the agent.

        Must be overridden by subclasses.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement postprocess")

    def run(self, query: str):
        """Submit ``query``, wait for the job and return its postprocessed output.

        While the job is unfinished its status is printed every
        ``poll_interval`` seconds. If a ``QueueError`` is raised while fetching
        or postprocessing the result, the string ``"QUEUE_FULL"`` is returned
        instead.
        """
        job = self.create_job(query)
        while not job.done():
            status = job.status()
            print(f"\nJob Status: {str(status.code)} eta: {status.eta}")
            time.sleep(self.poll_interval)
        try:
            return self.postprocess(job.result())
        except QueueError:
            return "QUEUE_FULL"

    # Optional gradio functionalities

    @staticmethod
    def _import_gradio(caller: str):
        """Import and return the ``gradio`` module.

        Raises:
            ModuleNotFoundError: gradio is unavailable; the message names the
                public method (``caller``) that needed it.
        """
        try:
            import gradio as gr
        except ImportError as err:  # also covers ModuleNotFoundError
            raise ModuleNotFoundError(
                f"gradio must be installed to call {caller}"
            ) from err
        return gr

    def _block_input(self, gr) -> "gr.components.Component":
        """Return the input component for this tool; a Textbox by default."""
        return gr.Textbox()

    def _block_output(self, gr) -> "gr.components.Component":
        """Return the output component for this tool; a Textbox by default."""
        return gr.Textbox()

    def block_input(self) -> "gr.components.Component":
        """Return the Gradio input component (requires gradio)."""
        return self._block_input(self._import_gradio("block_input"))

    def block_output(self) -> "gr.components.Component":
        """Return the Gradio output component (requires gradio)."""
        return self._block_output(self._import_gradio("block_output"))

    def block(self):
        """Get the gradio Blocks of this tool for visualization.

        The Space is loaded with ``gr.load(..., src="spaces")`` on first use
        and cached on the instance afterwards.
        """
        gr = self._import_gradio("block")
        if self._block is None:
            self._block = gr.load(name=self.src, src="spaces")
        return self._block

    # Optional langchain functionalities

    def langchain(self) -> "langchain.agents.Tool":  # type: ignore
        """Wrap this tool as a LangChain ``Tool`` whose function is ``run``.

        Raises:
            ModuleNotFoundError: langchain is not installed.
        """
        if not LANGCHAIN_INSTALLED:
            raise ModuleNotFoundError(
                "langchain must be installed to access langchain tool"
            )
        # Import the submodule explicitly: attribute access on the top-level
        # package only works if ``langchain.agents`` was already imported.
        from langchain.agents import Tool

        return Tool(name=self.name, func=self.run, description=self.description)

    def __repr__(self) -> str:
        return f"GradioTool(name={self.name}, src={self.src})"