# Scraped file-viewer header (not part of the program):
#   author: Chris4K
#   last commit: bd4ac0b (verified), message "Update app.py"
#   file size: 5.47 kB
import base64
import io
import os
import time

import IPython
import requests
import soundfile as sf
import streamlit as st
from PIL import Image
from pydub import AudioSegment
#from pydub.playback import Audio
def play_audio(audio, samplerate=16000, path="speech_converted.wav"):
    """Write *audio* to a WAV file and return an IPython audio player for it.

    Args:
        audio: Audio samples. Objects exposing ``.numpy()`` (presumably torch
            tensors -- TODO confirm with callers) are converted first; anything
            else is handed to ``soundfile.write`` unchanged.
        samplerate: Sample rate in Hz recorded in the file. Defaults to 16000,
            the value this helper always used.
        path: Destination file. It is overwritten on every call.

    Returns:
        ``IPython.display.Audio`` pointing at ``path``.
    """
    # Tensors need converting; plain arrays can be written directly.
    samples = audio.numpy() if hasattr(audio, "numpy") else audio
    sf.write(path, samples, samplerate=samplerate)
    return IPython.display.Audio(path)
# From transformers import BertModel, BertTokenizer
from transformers import HfAgent, load_tool
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Agent, LocalAgent
# checkpoint = "THUDM/agentlm-7b"
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# agent = LocalAgent(model, tokenizer)
# agent.run("Draw me a picture of rivers and lakes.")
# print(agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!"))
# Custom agent tools fetched from the Hub, loaded in this order.
random_character_tool, text_generation_tool = (
    load_tool(repo_id)
    for repo_id in ("Chris4K/random-character-tool", "Chris4K/text-generation-tool")
)
tools = [random_character_tool, text_generation_tool]
# Define the custom HfAgent class
class CustomHfAgent(Agent):
    """Agent that generates text by POSTing prompts to an inference endpoint.

    Args:
        url_endpoint: URL that each prompt payload is POSTed to.
        token: API token. A bare token is sent as ``"Bearer <token>"``; values
            already starting with ``"Bearer "`` or ``"Basic "`` are used as-is.
            When omitted, it is read from the ``HF_token`` environment variable
            when the agent is constructed.
        chat_prompt_template, run_prompt_template, additional_tools:
            Forwarded unchanged to ``Agent.__init__``.
        input_params: Optional generation settings. Only ``"max_new_tokens"``
            is read (default 192).
    """

    # How many times a rate-limited (HTTP 429) request is retried before
    # generate_one gives up and raises.
    max_rate_limit_retries = 10

    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Resolved here rather than in the signature. A signature default is
        # evaluated at import time and would raise KeyError whenever HF_token
        # is unset, even if the caller passes a token explicitly.
        if token is None:
            token = os.environ["HF_token"]
        # The Authorization header must carry an auth scheme, e.g. "Bearer <token>".
        if not token.startswith(("Bearer ", "Basic ")):
            token = f"Bearer {token}"
        self.token = token
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """Return the endpoint's completion for *prompt*.

        A stop sequence echoed at the end of the completion is stripped.
        HTTP 429 responses are retried after a 1-second pause, up to
        ``max_rate_limit_retries`` times.

        Raises:
            ValueError: If the final response status is not 200.
        """
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        inputs = {
            "inputs": prompt,
            # Explicit generation budget. The original author reported errors
            # when relying on the endpoint's default of 200.
            "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop},
        }
        for attempt in range(self.max_rate_limit_retries + 1):
            response = requests.post(self.url_endpoint, json=inputs, headers=headers)
            if response.status_code != 429 or attempt == self.max_rate_limit_retries:
                break
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        result = response.json()[0]["generated_text"]
        # The Inference API returns the stop sequence at the end of the text.
        for stop_seq in stop:
            # Skip empty sequences: result[:-0] would discard the whole text.
            if stop_seq and result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result
st.title("Hugging Face Agent and tools")
# Input field for the user's message. It is keyed so that a callback can
# reset it through st.session_state["message"].
message = st.text_input("Enter your message:", key="message")
# One checkbox per tool; each entry is True while that tool is ticked.
tool_checkboxes = [st.checkbox(f"Use {tool.name} --- {tool.description} ") for tool in tools]
# Streamlit re-executes this script on every interaction, so this reflects
# the current checkbox state.
selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]
# Initialize the agent with the ticked tools.
agent = CustomHfAgent(
    url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ['HF_token'],
    additional_tools=selected_tools,
    input_params={"max_new_tokens": 192},  # generation budget per LLM call
)
def _contains(response, needle):
    """Return ``needle in response``, or False if *response* does not support ``in``."""
    try:
        return needle in response
    except TypeError:
        return False


def _audio_segment_to_wav(segment):
    """Export a pydub ``AudioSegment`` to in-memory WAV bytes for ``st.audio``."""
    buffer = io.BytesIO()
    segment.export(buffer, format="wav")
    return buffer.getvalue()


def _render_response(response):
    """Display an agent reply on the page, choosing the widget by reply type."""
    if response is None:
        st.warning("The agent's response is None. Please try again. For Example: Generate an image of a boat in the water")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, str) and response.startswith("data:audio"):
        # Base64 data URI, e.g. "data:audio/wav;base64,<payload>".
        header, _, payload = response.partition(",")
        mime_type = header[len("data:"):].split(";")[0]
        st.audio(base64.b64decode(payload), format=mime_type)
    elif isinstance(response, AudioSegment):
        # st.audio cannot play an AudioSegment directly; hand it WAV bytes.
        st.audio(_audio_segment_to_wav(response), format="audio/wav")
    elif isinstance(response, str) or _contains(response, "text"):
        st.write(response)
    else:
        st.warning("Unrecognized response type. Please try again. For Example: Generate an image of a boat in the water")


# Define the callback function to handle the form submission
def handle_submission():
    """Send the current message to the agent and render its reply.

    Used as the Submit button's ``on_click`` callback. Reads the module-level
    ``message``, ``tools``, ``tool_checkboxes`` and ``agent``.
    """
    selected_tools = [tool for tool, checked in zip(tools, tool_checkboxes) if checked]
    print(selected_tools)
    # NOTE(review): this only has an effect if the Agent base class reads a
    # ``tools`` attribute -- confirm against transformers' Agent.
    agent.tools = selected_tools
    response = agent.chat(message)
    # Format instead of concatenating: the reply may be an image, audio or None.
    print(f"Response {response}")
    _render_response(response)
# Render the Submit button; clicking it runs handle_submission as a callback.
submit_button = st.button(label="Submit", on_click=handle_submission)
# Define a callback function to handle the button click
def ask_again():
    """Clear the message input and the agent's chat history for a fresh question.

    Intended as an ``on_click`` callback. Callbacks run before widgets are
    re-created on the next rerun, which is when Streamlit permits overwriting
    a keyed widget's value through ``st.session_state``.
    """
    # NOTE(review): assumes the message text_input is registered with
    # key="message" -- confirm. Without that key this only stores an unused
    # session-state entry.
    st.session_state["message"] = ""
    # Reset the conversation instead of sending an empty prompt to the LLM.
    agent.prepare_for_new_chat()
# Render the "Ask again" button with ask_again as its click callback. The
# result gets its own name so it does not shadow the ask_again function.
ask_again_button = st.button("Ask again", on_click=ask_again)