"""Streamlit chat front-end for a Hugging Face agent with selectable tools.

The page lists the tools that could be loaded as checkboxes, sends each chat
message to a remote text-generation endpoint through ``CustomHfAgent`` and
renders whatever object the agent returns (text, image, audio, table, chart).
"""

import base64
import io
import os
import time

import altair as alt  # For Altair charts
import IPython
import matplotlib.figure  # If you're using matplotlib figures
import numpy as np
import pandas as pd  # If you're working with DataFrames
import requests
import soundfile as sf
import streamlit as st
import torch
from bokeh.models import Plot  # For Bokeh charts
from PIL import Image
from pydub import AudioSegment
from transformers import Agent, load_tool

# NOTE: Streamlit ships no ``streamlit.graphics_*`` modules, so importing them
# fails at start-up. Plotly, Pydeck, Graphviz and Altair results are instead
# recognised by the package that defines their class (see ``_package_of``),
# which needs no extra imports.


class ToolLoader:
    """Load Hugging Face tools by name, skipping the ones that fail to load."""

    def __init__(self, tool_names):
        """Load every tool in *tool_names* and keep the successes in ``self.tools``."""
        self.tools = self.load_tools(tool_names)

    def load_tools(self, tool_names):
        """Return the list of tools that ``load_tool`` could load.

        A failure for one tool is printed and does not prevent the remaining
        tools from being loaded (deliberate best-effort behaviour).
        """
        loaded_tools = []
        for tool_name in tool_names:
            try:
                tool = load_tool(tool_name)
            except Exception as e:
                print(f"Error loading tool '{tool_name}': {e}")
            else:
                loaded_tools.append(tool)
        return loaded_tools


class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to an HTTP endpoint.

    Args:
        url_endpoint: URL the generation request is posted to.
        token: Value for the ``Authorization`` header. A token without a
            ``Bearer``/``Basic`` scheme gets ``Bearer `` prepended when the
            request is sent.
        chat_prompt_template: Forwarded to ``Agent.__init__``.
        run_prompt_template: Forwarded to ``Agent.__init__``.
        additional_tools: Forwarded to ``Agent.__init__``.
        input_params: Optional dict of generation settings; only
            ``max_new_tokens`` is read (default 192).
    """

    def __init__(
        self,
        url_endpoint,
        token,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def _auth_headers(self):
        """Build the request headers, adding a ``Bearer`` scheme when missing."""
        token = self.token
        if not token:
            # No token: send no Authorization header at all.
            return {}
        if not token.startswith(("Bearer", "Basic")):
            token = f"Bearer {token}"
        return {"Authorization": token}

    def generate_one(self, prompt, stop):
        """Generate one completion for *prompt* and strip a trailing stop sequence.

        Args:
            prompt: Text sent as ``inputs`` to the endpoint.
            stop: Iterable of stop sequences, also forwarded to the endpoint.

        Returns:
            The generated text, without the stop sequence it ends with (if any).

        Raises:
            ValueError: If the endpoint answers with a status other than 200
                or 429.
        """
        # ``input_params`` defaults to None; treat that as "no overrides".
        settings = self.input_params or {}
        parameters = {
            "max_new_tokens": settings.get("max_new_tokens", 192),
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=self._auth_headers())
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # Retry the same request, keeping the stop sequences.
            return self.generate_one(prompt, stop)
        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        print(response)
        result = response.json()[0]["generated_text"]
        for stop_seq in stop:
            # An empty stop sequence would make the slice below drop everything.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result


def load_tools(tool_names):
    """Load every tool in *tool_names*; unlike ``ToolLoader`` any failure propagates."""
    return [load_tool(tool_name) for tool_name in tool_names]


# Define the tool names to load
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    "Chris4K/sentiment-tool",
    "Chris4K/EmojifyTextTool",
    # Add other tool names as needed
]

# Create tool loader instance
tool_loader = ToolLoader(tool_names)


# Define the callback function to handle the form submission
def handle_submission(user_message, selected_tools):
    """Run the agent on *user_message* with *selected_tools* and return its result.

    The access token is read from the ``HF_token`` environment variable; a
    missing variable raises ``KeyError``.
    """
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )
    response = agent.run(user_message)
    print("Agent Response\n {}".format(response))
    return response


def _package_of(obj):
    """Return the top-level package name of the class of *obj* (e.g. ``"plotly"``)."""
    return type(obj).__module__.split(".", 1)[0]


def _is_vega_lite_spec(obj):
    """Return True for a dict whose ``$schema`` entry mentions Vega-Lite."""
    return isinstance(obj, dict) and "vega-lite" in str(obj.get("$schema", ""))


# Chart objects recognised by defining package -> Streamlit call that draws them.
_CHART_RENDERERS = {
    "altair": st.altair_chart,
    "graphviz": st.graphviz_chart,
    "plotly": st.plotly_chart,
    "pydeck": st.pydeck_chart,
}


def render_response(response):
    """Render *response* in the current Streamlit container according to its type.

    Used both for fresh agent answers and for replaying the stored history, so
    non-text answers are shown the same way after a rerun.
    """
    if response is None:
        st.warning("The agent's response is None. Please try again. Generate an image of a flying horse.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        # st.audio does not accept an AudioSegment; hand it WAV bytes instead.
        wav_buffer = io.BytesIO()
        response.export(wav_buffer, format="wav")
        st.audio(wav_buffer.getvalue(), format="audio/wav")
    elif isinstance(response, (int, float)):
        st.markdown(response)
    elif isinstance(response, str):
        st.markdown(response)
    elif isinstance(response, list):
        for item in response:
            st.markdown(item)  # Assuming the list contains strings
    elif isinstance(response, pd.DataFrame):
        st.dataframe(response)
    elif isinstance(response, pd.Series):
        st.table(response.iloc[0:10])
    elif _is_vega_lite_spec(response):
        st.vega_lite_chart(response)
    elif isinstance(response, dict):
        if "emojified_text" in response:
            st.markdown(f"{response['emojified_text']}")
        else:
            st.json(response)
    elif isinstance(response, Plot):
        st.bokeh_chart(response)
    elif isinstance(response, matplotlib.figure.Figure):
        st.pyplot(response)
    elif _package_of(response) in _CHART_RENDERERS:
        _CHART_RENDERERS[_package_of(response)](response)
    else:
        st.warning("Unrecognized response type. Please try again. e.g. Generate an image of a flying horse.")


st.title("Hugging Face Agent and tools")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far; stored content may be any renderable object.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        render_response(message["content"])

# One checkbox per loaded tool, in the same order as ``tool_loader.tools``.
tool_checkboxes = [st.checkbox(f"{tool.name} --- {tool.description} ") for tool in tool_loader.tools]

with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    selected_tools = [tool for tool, checked in zip(tool_loader.tools, tool_checkboxes) if checked]
    response = handle_submission(user_message, selected_tools)

    with st.chat_message("assistant"):
        render_response(response)

    st.session_state.messages.append({"role": "assistant", "content": response})