Spaces:
Runtime error
Runtime error
File size: 3,881 Bytes
e67043b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from __future__ import annotations
import time
from abc import abstractmethod
from typing import Any, Tuple, Union
import gradio_client as grc
import huggingface_hub
from gradio_client.client import Job
from gradio_client.utils import QueueError
# langchain is an optional dependency; GradioTool.langchain checks this flag
# before building a langchain Tool. ImportError also covers
# ModuleNotFoundError, which is a subclass of it.
try:
    import langchain as lc
except ImportError:
    LANGCHAIN_INSTALLED = False
else:
    LANGCHAIN_INSTALLED = True
class GradioTool:
    """Wrap a Gradio app as a tool that takes a text query and returns a string.

    Subclasses implement :meth:`create_job` (submit *query* through
    ``self.client`` and return the resulting ``Job``) and :meth:`postprocess`
    (turn the job's result into a string).

    Args:
        name: Tool name exposed to agents.
        description: Natural-language description of the tool. A hint about
            the ``QUEUE_FULL`` sentinel returned by :meth:`run` is appended.
        src: Hugging Face Space id or URL of the Gradio app.
        hf_token: Optional Hugging Face token, forwarded to ``grc.Client``.
        duplicate: When True, a token is given and *src* is an existing
            Space, the Space is duplicated via ``grc.Client.duplicate`` and
            the duplicate's ``space_id`` becomes ``self.src``.
    """

    # Appended to every description so an agent retries on a full queue.
    _QUEUE_FULL_HINT = (
        "If the output is the string QUEUE_FULL you "
        "should try this action again and never stop until you "
        "don't see this message."
    )

    def __init__(
        self,
        name: str,
        description: str,
        src: str,
        hf_token: str | None = None,
        duplicate: bool = True,
    ) -> None:
        self.name = name
        self.description = self._append_queue_full_hint(description)
        # `duplicate` is checked before `_is_space` so that no network request
        # is made when duplication is disabled.
        if hf_token and duplicate and self._is_space(src, hf_token):
            self.client = grc.Client.duplicate(from_id=src, hf_token=hf_token)
            self.src = self.client.space_id
        else:
            self.src = src
            self.client = grc.Client(self.src, hf_token=hf_token)
        # Lazily created by `block()`.
        self._block = None

    @classmethod
    def _append_queue_full_hint(cls, description: str) -> str:
        """Return *description* followed by the QUEUE_FULL retry hint.

        A single space is inserted between the two unless *description* is
        empty or already ends in whitespace, so the sentences don't run
        together.
        """
        if description and not description[-1].isspace():
            description += " "
        return description + cls._QUEUE_FULL_HINT

    @staticmethod
    def _is_space(src: str, hf_token: str | None = None) -> bool:
        """Return True if *src* refers to an existing Hugging Face Space.

        Values that are not valid repo ids (e.g. a URL such as
        ``http://localhost:7860``) make ``huggingface_hub`` raise
        ``HFValidationError``. They are reported as "not a Space" instead of
        propagating. *hf_token* is forwarded so that private Spaces visible to
        that token are found. ``None`` keeps huggingface_hub's default token
        lookup.
        """
        from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError

        try:
            huggingface_hub.get_space_runtime(src, token=hf_token)
        except (RepositoryNotFoundError, HFValidationError):
            return False
        return True

    @abstractmethod
    def create_job(self, query: str) -> Job:
        """Submit *query* to the Gradio app and return the pending job."""
        # GradioTool has no ABC metaclass, so @abstractmethod is not enforced.
        # Fail loudly instead of returning None into `run`.
        raise NotImplementedError(
            f"{type(self).__name__} must implement create_job()"
        )

    @abstractmethod
    def postprocess(self, output: Union[Tuple[Any], Any]) -> str:
        """Convert the job's raw result into the tool's string output."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement postprocess()"
        )

    def run(self, query: str, *, poll_interval: float = 30) -> str:
        """Submit *query*, wait for the job to finish and return its output.

        While the job is running, its status is printed and the job is
        polled every *poll_interval* seconds. If fetching the result raises
        ``QueueError``, the sentinel string ``"QUEUE_FULL"`` is returned
        instead of the postprocessed result.
        """
        job = self.create_job(query)
        while not job.done():
            status = job.status()
            print(f"\nJob Status: {str(status.code)} eta: {status.eta}")
            time.sleep(poll_interval)
        try:
            result = job.result()
        except QueueError:
            return "QUEUE_FULL"
        return self.postprocess(result)

    # Optional gradio functionalities

    @staticmethod
    def _import_gradio(caller: str):
        """Import and return the ``gradio`` module.

        Raises:
            ModuleNotFoundError: if gradio cannot be imported; the message
                names the public method (*caller*) that needed it.
        """
        try:
            import gradio as gr
        except ImportError as err:  # also covers ModuleNotFoundError
            raise ModuleNotFoundError(
                f"gradio must be installed to call {caller}"
            ) from err
        return gr

    def _block_input(self, gr) -> "gr.components.Component":
        """Build the input component; subclasses may override."""
        return gr.Textbox()

    def _block_output(self, gr) -> "gr.components.Component":
        """Build the output component; subclasses may override."""
        return gr.Textbox()

    def block_input(self) -> "gr.components.Component":
        """Return the gradio input component for this tool."""
        return self._block_input(self._import_gradio("block_input"))

    def block_output(self) -> "gr.components.Component":
        """Return the gradio output component for this tool."""
        return self._block_output(self._import_gradio("block_output"))

    def block(self):
        """Get the gradio Blocks of this tool for visualization.

        The Blocks are loaded from ``self.src`` on first call and cached.
        """
        gr = self._import_gradio("block")
        if self._block is None:
            self._block = gr.load(name=self.src, src="spaces")
        return self._block

    # Optional langchain functionalities

    @property
    def langchain(self) -> "langchain.agents.Tool":  # type: ignore
        """Return a langchain ``Tool`` that calls :meth:`run`.

        Raises:
            ModuleNotFoundError: if langchain is not installed.
        """
        if not LANGCHAIN_INSTALLED:
            raise ModuleNotFoundError(
                "langchain must be installed to access langchain tool"
            )
        # Import the submodule explicitly: `import langchain` only exposes
        # `langchain.agents` if the package's __init__ happens to import it.
        from langchain.agents import Tool

        return Tool(name=self.name, func=self.run, description=self.description)

    def __repr__(self) -> str:
        return f"GradioTool(name={self.name}, src={self.src})"
|