import base64
import io
import os
import time

import requests
import streamlit as st
import torch
from PIL import Image
from pydub import AudioSegment
from transformers import Agent, AutoModelForCausalLM, AutoTokenizer, HfAgent, LocalAgent, load_tool

# from pydub.playback import Audio
# from transformers import BertModel, BertTokenizer
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Load the community tools the agent may be granted access to.
controlnet_transformer = load_tool("Chris4K/random-character-tool")
upscaler = load_tool("Chris4K/text-generation-tool")

# All available tools; the user selects a subset via checkboxes below.
tools = [controlnet_transformer, upscaler]
# Custom Agent whose LLM calls are served by a remote HF Inference endpoint.
class CustomHfAgent(Agent):
    """Agent backed by a Hugging Face Inference API endpoint.

    Parameters
    ----------
    url_endpoint : str
        Full URL of the inference endpoint (e.g. the starcoder model API).
    token : str, optional
        HF API token. Defaults to the ``HF_token`` environment variable,
        resolved at construction time (the original read the variable in
        the function default, i.e. at import time, raising KeyError even
        when a token was passed explicitly).
    chat_prompt_template, run_prompt_template, additional_tools :
        Forwarded unchanged to ``transformers.Agent``.
    """

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolve the env var lazily so importing this module never raises.
        self.token = token if token is not None else os.environ["HF_token"]

    def generate_one(self, prompt, stop):
        """POST `prompt` to the endpoint and return the generated text.

        Retries (with a 1s pause) while the endpoint answers HTTP 429, and
        strips a trailing stop sequence from the result, since the
        Inference API echoes the stop sequence back.
        """
        # NOTE(review): the Inference API usually expects "Bearer <token>";
        # the raw token is kept to preserve existing behavior — confirm.
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: the original retried via self._generate_one(prompt),
            # a method this class does not define (AttributeError), and it
            # dropped `stop`. Retry through generate_one with `stop` intact.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        result = response.json()[0]["generated_text"]
        # The Inference API returns the stop sequence; trim it off.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
# --- Streamlit UI -----------------------------------------------------------
st.title("Hugging Face Agent")

# Free-text field for the user's instruction to the agent.
message = st.text_input("Enter your message:", "")

# One checkbox per available tool; index-aligned with `tools`.
tool_checkboxes = [st.checkbox(f"Use {tool.name} --- {tool.description} ") for tool in tools]
# Define the callback function to handle the form submission | |
def handle_submission():
    """Run the agent on the current `message` with the checked tools.

    Builds a fresh ``CustomHfAgent`` so the tool selection takes effect,
    then renders the agent's answer as text, image, or audio depending on
    its runtime type; unknown types produce a Streamlit warning.
    """
    # Tools whose checkbox is ticked (checkboxes and tools are index-aligned).
    selected_tools = [tool for idx, tool in enumerate(tools) if tool_checkboxes[idx]]
    print(selected_tools)

    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
    )

    # Run the agent with the user's message and selected tools.
    response = agent.run(message)
    print(response)

    # Dispatch on the response's runtime type. BUG FIX: the original ran
    # `"audio" in response` on arbitrary objects, which raises TypeError
    # for non-container types (e.g. AudioSegment was tested only AFTER the
    # `in` check and could never be reached without crashing). Membership
    # tests are now guarded by isinstance checks.
    if response is None:
        st.warning("The agent's response is None.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, str) and "audio" in response:
        # Data-URI style payload: "<header>,<base64 data>".
        audio_data = base64.b64decode(response.split(",")[1])
        audio = AudioSegment.from_file(io.BytesIO(audio_data))
        st.audio(audio)
    elif isinstance(response, AudioSegment):
        st.audio(response)
    elif isinstance(response, str):
        st.write(response)
    elif isinstance(response, dict) and "text" in response:
        # Preserves the original's fallback path for mapping-like results.
        st.write(response)
    else:
        st.warning("Unrecognized response type.")
# Wire the Submit button to the callback; clicking runs handle_submission.
submit_button = st.button("Submit", on_click=handle_submission)
# Callback for a (currently unwired) "Ask Again" button.
def ask_again():
    """Clear the message box and re-run the agent with an empty prompt.

    NOTE(review): this function is never attached to a widget (the button
    hookup was left commented out), and the original referenced two
    undefined names (`message_input`, `agent`) that would raise NameError
    if it were ever called. It now builds its own agent and resets the
    input through session state so it is safe to wire up later.
    """
    # Streamlit text inputs are reset via session state, not a .value
    # attribute. TODO: give the text_input above key="message" so this
    # assignment actually clears the widget.
    st.session_state["message"] = ""
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
    )
    agent.run("")