import base64
import io
import os
import time

import requests
import streamlit as st
from PIL import Image
from pydub import AudioSegment
from transformers import Agent, load_tool

# Alternative: run the agent fully locally instead of over the Inference API
# (requires torch and AutoModelForCausalLM/AutoTokenizer/LocalAgent from transformers):
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Load tools from the Hugging Face Hub
random_character_tool = load_tool("Chris4K/random-character-tool")
text_generation_tool = load_tool("Chris4K/text-generation-tool")
tools = [random_character_tool, text_generation_tool]
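
# A hedged sketch of how a custom tool could be defined locally instead of
# being loaded from the Hub. EchoTool and its echo behavior are illustrative,
# not part of this app:
#
#   from transformers import Tool
#
#   class EchoTool(Tool):
#       name = "echo"
#       description = "Returns the input text unchanged."
#       inputs = ["text"]
#       outputs = ["text"]
#
#       def __call__(self, text: str) -> str:
#           return text
#
#   tools.append(EchoTool())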
# Custom agent that generates text through a Hugging Face Inference API endpoint
class CustomHfAgent(Agent):
    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Fall back to the HF_token environment variable if no token is given;
        # looking it up lazily avoids a KeyError at class-definition time
        self.token = token if token is not None else os.environ.get("HF_token")
    def generate_one(self, prompt, stop):
        # The Inference API expects a Bearer token in the Authorization header
        headers = {"Authorization": f"Bearer {self.token}"}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 192, "return_full_text": False, "stop": stop},
        }

        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {inputs} {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # The Inference API returns the stop sequence; strip it off
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
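
# A minimal usage sketch (hedged): the class above can be exercised outside
# Streamlit, assuming the HF_token environment variable is set; the prompt is
# purely illustrative:
#
#   agent = CustomHfAgent(
#       url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
#   )
#   print(agent.run("Generate a short greeting."))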
# Build the Streamlit app
st.title("Hugging Face Agent")

# Input field for the user's message (keyed so callbacks can reset it)
message = st.text_input("Enter your message:", "", key="message_input")

# Checkboxes for the tools the agent may use
tool_checkboxes = [st.checkbox(f"Use {tool.name} --- {tool.description}") for tool in tools]
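
# Alternative (hedged sketch): a single st.multiselect could replace the
# per-tool checkboxes; tool names come from the tools loaded above:
#
#   selected_names = st.multiselect("Tools", [tool.name for tool in tools])
#   selected_tools = [tool for tool in tools if tool.name in selected_names]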
# Callback that runs the agent when the Submit button is clicked
def handle_submission():
    # Collect the tools whose checkboxes are ticked
    selected_tools = [tool for idx, tool in enumerate(tools) if tool_checkboxes[idx]]
    print(selected_tools)

    # Initialize the agent with the selected tools
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ.get("HF_token"),
        additional_tools=selected_tools,
    )

    # Run the agent with the user's message
    # (agent.chat(message) would instead keep state across turns)
    response = agent.run(message)
    print(response)
    # Display the agent's response, dispatching on its type
    if response is None:
        st.warning("The agent's response is None.")
    elif isinstance(response, Image.Image):
        # Image response
        st.image(response)
    elif isinstance(response, AudioSegment):
        # Audio response; export to WAV bytes for st.audio
        st.audio(response.export(format="wav").read())
    elif isinstance(response, str) and "audio" in response:
        # Base64-encoded audio data URI
        audio_data = base64.b64decode(response.split(",")[1])
        audio = AudioSegment.from_file(io.BytesIO(audio_data))
        st.audio(audio.export(format="wav").read())
    elif isinstance(response, str):
        # Plain text response
        st.write(response)
    else:
        # Unrecognized response type
        st.warning("Unrecognized response type.")
# Wire the callback to the Submit button
submit_button = st.button("Submit", on_click=handle_submission)
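
# Note (hedged): output rendered inside an on_click callback is cleared on the
# next script rerun. A more robust pattern is to stash the result in
# st.session_state inside handle_submission and render it from the main body:
#
#   st.session_state["last_response"] = response        # inside handle_submission
#
#   if "last_response" in st.session_state:             # in the main script body
#       st.write(st.session_state["last_response"])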
# Callback to let the user ask again: clear the input field via session state
def ask_again():
    st.session_state["message_input"] = ""

# Wire the callback to an "Ask Again" button
st.button("Ask Again", on_click=ask_again)