import os
import io
import time

import requests
import streamlit as st
from PIL import Image
from pydub import AudioSegment
import pandas as pd  # for DataFrame/Series responses
import matplotlib.figure  # for matplotlib Figure responses

# Chart libraries used only for the isinstance checks on the agent's
# response further below. The `streamlit.graphics_*` modules imported in
# the original do not exist in Streamlit; these are the real classes the
# checks need (they must be installed in the Space's environment).
import altair as alt
import plotly.graph_objects as go
import pydeck as pdk

from transformers import load_tool, Agent
class ToolLoader:
    """Loads Hub tools by name, skipping any that fail to load."""

    def __init__(self, tool_names):
        self.tools = self.load_tools(tool_names)

    def load_tools(self, tool_names):
        loaded_tools = []
        for tool_name in tool_names:
            try:
                tool = load_tool(tool_name)
                loaded_tools.append(tool)
            except Exception as e:
                # Skip the broken tool so one bad tool doesn't take down the app.
                print(f"Error loading tool '{tool_name}': {e}")
        return loaded_tools
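# A minimal usage sketch (the tool name here is illustrative, not
# guaranteed to resolve on the Hub):
#
#   loader = ToolLoader(["Chris4K/sentiment-tool"])
#   for t in loader.tools:
#       print(t.name, "-", t.description)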
class CustomHfAgent(Agent):
    """Agent that generates text through a custom HF Inference API endpoint."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params or {}  # guard against None so .get() below is safe
    def generate_one(self, prompt, stop):
        # The Inference API expects a value like "Bearer <hf_token>" here.
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Rate-limited; waiting briefly before retrying.")
            time.sleep(1)
            # (The original called self._generate_one(prompt): wrong name, missing `stop`.)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code} for request {inputs}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Strip a trailing stop sequence, if present.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
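# The request/response shapes above follow the HF Inference API for text
# generation: the POST body is {"inputs": <prompt>, "parameters": {...}}
# and a successful response is a JSON list like [{"generated_text": "..."}],
# which is why the code indexes [0] before reading "generated_text".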
# Tool names to load from the Hugging Face Hub
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    "Chris4K/sentiment-tool",
    "Chris4K/EmojifyTextTool",
    # Add other tool names as needed
]
# Create tool loader instance
tool_loader = ToolLoader(tool_names)
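# Note: load_tool() fetches each tool's implementation from the Hugging Face
# Hub, so the first startup needs network access and may be slow.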
# Callback to handle a submitted chat message
def handle_submission(user_message, selected_tools):
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ["HF_token"],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )
    response = agent.run(user_message)
    print(f"Agent response:\n{response}")
    return response
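# Agent.run() has the LLM write and execute a small Python program that
# calls the selected tools, so the returned value's type depends on the
# tool that ran last (str, PIL.Image.Image, AudioSegment, ...). That is
# why the rendering code below dispatches on the response type.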
st.title("Hugging Face Agent and Tools")

# Initialize and replay the chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# One checkbox per successfully loaded tool
tool_checkboxes = [st.checkbox(f"{tool.name} --- {tool.description}") for tool in tool_loader.tools]
with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")
if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    selected_tools = [tool_loader.tools[idx] for idx, checkbox in enumerate(tool_checkboxes) if checkbox]

    response = handle_submission(user_message, selected_tools)

    with st.chat_message("assistant"):
        if response is None:
            st.warning("The agent returned no response. Please try again, e.g. 'Generate an image of a flying horse.'")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            # st.audio cannot play an AudioSegment directly; export it to WAV bytes first.
            buffer = io.BytesIO()
            response.export(buffer, format="wav")
            st.audio(buffer.getvalue(), format="audio/wav")
        elif isinstance(response, int):
            st.markdown(str(response))
        elif isinstance(response, str):
            # (The original indexed response["emojified_text"] here, but a str
            # cannot be indexed by key; the emojify tool's output is plain text.)
            st.markdown(response)
        elif isinstance(response, list):
            for item in response:
                st.markdown(item)  # assumes the list contains strings
        elif isinstance(response, pd.DataFrame):
            st.dataframe(response)
        elif isinstance(response, pd.Series):
            st.table(response.iloc[0:10])
        elif isinstance(response, dict):
            st.json(response)
        elif isinstance(response, alt.Chart):
            st.altair_chart(response)
        elif isinstance(response, go.Figure):
            st.plotly_chart(response)
        elif isinstance(response, pdk.Deck):
            st.pydeck_chart(response)
        elif isinstance(response, matplotlib.figure.Figure):
            st.pyplot(response)
        # The original also branched on streamlit.graphics_* types for bokeh,
        # graphviz, and vega-lite charts; those modules do not exist in
        # Streamlit, so those cases fall through to the warning below.
        else:
            st.warning("Unrecognized response type. Please try again, e.g. 'Generate an image of a flying horse.'")

    # Note: replaying non-text content via st.markdown in the history loop
    # above will not render correctly; history replay assumes string content.
    st.session_state.messages.append({"role": "assistant", "content": response})