import base64
import io
import os
import time

import requests
import soundfile as sf
import streamlit as st
from IPython.display import Audio
from PIL import Image
from pydub import AudioSegment
from transformers import Agent, load_tool


def play_audio(audio):
    # Persist the model output as a 16 kHz WAV file and wrap it for playback.
    sf.write("speech_converted.wav", audio.numpy(), samplerate=16000)
    return Audio("speech_converted.wav")
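

# Community tools pulled from the Hugging Face Hub; they are offered to the
# user as checkboxes below and attached to the agent on each submission.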
random_character_tool = load_tool("Chris4K/random-character-tool")
text_generation_tool = load_tool("Chris4K/text-generation-tool")

tools = [random_character_tool, text_generation_tool]
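

# CustomHfAgent is modeled on transformers' HfAgent: it subclasses Agent and
# overrides generate_one() so that prompts are sent to a hosted text-generation
# Inference API endpoint (StarCoder below) instead of a local model.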
class CustomHfAgent(Agent):
    def __init__(
        self,
        url_endpoint,
        token=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
        input_params=None,
    ):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        # Fall back to the HF_token environment variable so a missing variable
        # does not raise at class-definition time.
        self.token = token if token is not None else os.environ.get("HF_token")
        self.input_params = input_params or {}

    def generate_one(self, prompt, stop):
        # The Inference API expects a "Bearer <token>" authorization header.
        headers = {"Authorization": f"Bearer {self.token}"}

        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop},
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code} for payload {inputs}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Strip the stop sequence the model was asked to end on, if present.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result


st.title("Hugging Face Agent and tools")

# key="message" lets the "Ask again" callback clear the field via session_state.
message = st.text_input("Enter your message:", key="message")

tool_checkboxes = [st.checkbox(f"Use {tool.name} --- {tool.description}") for tool in tools]

agent = CustomHfAgent(
    url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
    token=os.environ.get("HF_token"),
    # The toolbox is narrowed down to the user's selection in handle_submission().
    additional_tools=tools,
    input_params={"max_new_tokens": 192},
)
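

# Runs when the user clicks "Submit": attach the selected tools to the agent,
# forward the message to agent.chat(), and render the result according to its type.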
def handle_submission():
    selected_tools = [tool for idx, tool in enumerate(tools) if tool_checkboxes[idx]]
    print(selected_tools)

    agent.tools = selected_tools

    response = agent.chat(message)
    print("Response:", response)

    if response is None:
        st.warning("The agent's response is None. Please try again. For example: Generate an image of a boat in the water.")
    elif isinstance(response, Image.Image):
        st.image(response)
    elif isinstance(response, AudioSegment):
        # st.audio cannot play an AudioSegment directly, so export it to WAV bytes.
        st.audio(response.export(format="wav").read())
    elif isinstance(response, str) and response.startswith("data:audio"):
        # Base64-encoded audio data URI.
        audio_data = base64.b64decode(response.split(",")[1])
        audio = AudioSegment.from_file(io.BytesIO(audio_data))
        st.audio(audio.export(format="wav").read())
    elif isinstance(response, str):
        st.write(response)
    elif isinstance(response, dict) and "text" in response:
        st.write(response["text"])
    else:
        st.warning("Unrecognized response type. Please try again. For example: Generate an image of a boat in the water.")


submit_button = st.button("Submit", on_click=handle_submission)


def ask_again():
    # Clear the input box (bound to session_state via key="message") and reset
    # the agent's chat history so the next question starts a fresh conversation.
    st.session_state["message"] = ""
    agent.prepare_for_new_chat()


ask_again_button = st.button("Ask again", on_click=ask_again)