|
import streamlit as st |
|
import os |
|
import base64 |
|
import io |
|
from PIL import Image |
|
from pydub import AudioSegment |
|
import IPython |
|
import soundfile as sf |
|
import requests |
|
import pandas as pd |
|
import matplotlib.figure |
|
|
|
|
|
import streamlit.graphics_altair |
|
|
|
import streamlit.graphics_bokeh |
|
|
|
|
|
import streamlit.graphics_plotly |
|
|
|
|
|
import streamlit.graphics_pydeck |
|
|
|
import streamlit.graphics_vega_lite |
|
|
|
|
|
import time |
|
from transformers import load_tool, Agent |
|
import torch |
|
|
|
|
|
class ToolLoader:
    """Load a collection of agent tools by name, skipping any that fail."""

    def __init__(self, tool_names):
        """Load every tool named in *tool_names*.

        Args:
            tool_names: Iterable of tool identifiers, each passed to
                ``load_tool``.
        """
        # Tools that loaded successfully, in the order their names were given.
        self.tools = self.load_tools(tool_names)

    def load_tools(self, tool_names):
        """Return the list of tools that could be loaded from *tool_names*.

        Loading is best-effort: when ``load_tool`` raises for a name, the
        error is printed and that name is skipped, so the returned list may
        be shorter than *tool_names*.

        Args:
            tool_names: Iterable of tool identifiers.

        Returns:
            A list with one loaded tool per name that loaded without error.
        """
        loaded_tools = []
        for tool_name in tool_names:
            # Keep the try body to the one call that is expected to fail.
            try:
                tool = load_tool(tool_name)
            except Exception as e:  # deliberate: one bad tool must not stop the rest
                print(f"Error loading tool '{tool_name}': {e}")
            else:
                loaded_tools.append(tool)
        return loaded_tools
|
|
|
class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to an HTTP endpoint."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        """Store the endpoint settings and initialise the base agent.

        Args:
            url_endpoint: URL that ``generate_one`` POSTs to.
            token: Value for the ``Authorization`` header. A string without
                a ``Bearer``/``Basic`` scheme gets ``Bearer `` prepended; a
                string that already has one is used unchanged.
            chat_prompt_template: Forwarded to ``Agent.__init__``.
            run_prompt_template: Forwarded to ``Agent.__init__``.
            additional_tools: Forwarded to ``Agent.__init__``.
            input_params: Optional dict of generation settings; only
                ``"max_new_tokens"`` is read by ``generate_one``.
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # NOTE(review): the Hugging Face Inference API expects
        # "Authorization: Bearer <token>"; a bare token is prefixed here so
        # both bare and already-prefixed values work.
        if isinstance(token, str) and token and not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """Request one completion for *prompt* and strip a trailing stop sequence.

        Args:
            prompt: Text sent as the ``"inputs"`` field of the request.
            stop: Iterable of stop sequences; sent to the endpoint and also
                removed from the end of the generated text if present.

        Returns:
            The ``"generated_text"`` of the first result, without a trailing
            stop sequence.

        Raises:
            ValueError: If the endpoint answers with a status other than
                200 (429 is retried instead of raised).
        """
        headers = {"Authorization": self.token}
        # input_params defaults to None in __init__, so guard before .get().
        max_new_tokens = (self.input_params or {}).get("max_new_tokens", 192)
        parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }

        # Retry in a loop while rate-limited. The previous recursive retry
        # called a non-existent method and dropped the `stop` argument.
        while True:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)

        if response.status_code != 200:
            # Use the raw body: an error response is not guaranteed to be JSON,
            # and a decode failure here would hide the real status code.
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.text}")
        print(response)
        result = response.json()[0]["generated_text"]
        for stop_seq in stop:
            # Skip empty sequences: result[:-0] would wrongly yield "".
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
|
|
|
def load_tools(tool_names):
    """Load every tool named in *tool_names* and return them as a list.

    Each name is passed to ``load_tool``; nothing is caught here, so the
    first failure propagates to the caller.

    Args:
        tool_names: Iterable of tool identifiers.

    Returns:
        A list of loaded tools, in the same order as *tool_names*.
    """
    return [load_tool(tool_name) for tool_name in tool_names]
|
|
|
|
|
# Identifiers of the tools offered as checkboxes in the UI; each one is
# handed to ToolLoader, which passes it to load_tool.
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    "Chris4K/sentiment-tool",
    "Chris4K/EmojifyTextTool",
]

# Built at module level, so the tools are loaded every time this script runs.
tool_loader = ToolLoader(tool_names)
|
|
|
|
|
def handle_submission(user_message, selected_tools):
    """Run *user_message* through a newly created agent and return the result.

    A new ``CustomHfAgent`` pointed at the StarCoder inference endpoint is
    built on every call, with *selected_tools* as its additional tools.

    Args:
        user_message: Text passed to ``agent.run``.
        selected_tools: Tools passed to the agent as ``additional_tools``.

    Returns:
        Whatever ``agent.run`` returns; the value is also printed.

    Raises:
        KeyError: If the ``HF_token`` environment variable is not set.
    """
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ["HF_token"],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )

    response = agent.run(user_message)

    print(f"Agent Response\n {response}")

    return response
|
|
|
# Top-level package of a chart object's class -> name of the Streamlit
# function that draws it. Names are resolved with getattr at render time so
# a chart function missing from the installed Streamlit only matters when
# such a chart is actually returned.
# NOTE(review): this replaces isinstance checks against
# streamlit.graphics_*.<X>Chart classes, which do not appear in Streamlit's
# documented API (and graphics_graphviz was never imported) — confirm.
_CHART_RENDERERS = {
    "altair": "altair_chart",
    "bokeh": "bokeh_chart",
    "graphviz": "graphviz_chart",
    "plotly": "plotly_chart",
    "pydeck": "pydeck_chart",
}


def _render_response(response):
    """Draw *response* in the current Streamlit container, chosen by type.

    Used for new agent responses and for replaying stored chat history, so
    non-text responses (images, audio, tables, charts) replay the same way
    they were first shown.
    """
    if response is None:
        st.warning("The agent's response is None. Please try again. Generate an image of a flying horse.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        # st.audio takes bytes/paths/arrays, not pydub segments, so export
        # the segment to in-memory WAV data first.
        wav_buffer = io.BytesIO()
        response.export(wav_buffer, format="wav")
        st.audio(wav_buffer.getvalue(), format="audio/wav")
    elif isinstance(response, int):
        st.markdown(str(response))
    elif isinstance(response, str):
        # A string cannot be indexed by key; the earlier
        # response['emojified_text'] lookup raised TypeError here.
        st.markdown(response)
    elif isinstance(response, list):
        for item in response:
            st.markdown(item)
    elif isinstance(response, pd.DataFrame):
        st.dataframe(response)
    elif isinstance(response, pd.Series):
        # Only the first ten entries are shown.
        st.table(response.iloc[0:10])
    elif isinstance(response, dict):
        # NOTE(review): the 'emojified_text' lookup only makes sense for a
        # mapping, so it is applied here — confirm this matches the tool's
        # actual return value.
        if "emojified_text" in response:
            st.markdown(f"{response['emojified_text']}")
        else:
            st.json(response)
    elif isinstance(response, matplotlib.figure.Figure):
        st.pyplot(response)
    else:
        root_package = (type(response).__module__ or "").split(".")[0]
        renderer_name = _CHART_RENDERERS.get(root_package)
        if renderer_name is not None:
            getattr(st, renderer_name)(response)
        else:
            # NOTE(review): raw Vega-Lite specs are plain dicts and are shown
            # by the dict branch above; there is no separate chart branch.
            st.warning("Unrecognized response type. Please try again. e.g. Generate an image of a flying horse.")


st.title("Hugging Face Agent and tools")

# Chat history lives in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far; content may be any type the agent returned.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        _render_response(message["content"])

# One checkbox per loaded tool; the list holds each checkbox's current state.
tool_checkboxes = [st.checkbox(f"{tool.name} --- {tool.description} ") for tool in tool_loader.tools]

with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    # Only the tools whose checkbox is ticked are handed to the agent.
    selected_tools = [tool for tool, checked in zip(tool_loader.tools, tool_checkboxes) if checked]
    response = handle_submission(user_message, selected_tools)

    with st.chat_message("assistant"):
        _render_response(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
|
|