import gradio as gr
import torch
import os
import numpy as np
from groq import Groq
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain import LLMChain, PromptTemplate
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory
# Assumption: the agent and the RetrievalQA chain need a LangChain-compatible LLM, so the
# ChatGroq wrapper (langchain-groq package) is used below instead of the raw groq SDK client
from langchain_groq import ChatGroq
from PIL import Image
from decord import VideoReader, cpu
from tavily import TavilyClient
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Initialize models and clients
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
MODEL = 'llama3-groq-70b-8192-tool-use-preview'

vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                      device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

# Image generation model
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")
image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")

# Tavily Client for web search
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
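
# NOTE: the app reads GROQ_API_KEY and TAVILY_API from the environment (see above);
# the document QA tool below additionally relies on OPENAI_API_KEY for OpenAIEmbeddings.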

# Function to play voice output
def play_voice_output(response):
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"

# NumPy Code Calculator Tool
def numpy_code_calculator(query):
    try:
        llm_response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "user", "content": f"Write NumPy code to: {query}"}
            ]
        )
        code = llm_response.choices[0].message.content
        print(f"Generated NumPy code:\n{code}")
        # Execute the generated code (note: exec() on raw LLM output is not sandboxed)
        local_dict = {"np": np}
        exec(code, local_dict)
        result = local_dict.get("result", "No result found")
        return str(result)
    except Exception as e:
        return f"Error: {e}"

# Web Search Tool
def web_search(query):
    answer = tavily_client.qna_search(query=query)
    return answer

# Image Generation Tool
def image_generation(query):
    # The SDXL-Lightning 4-step UNet is distilled for 4 inference steps with guidance disabled
    image = image_pipe(prompt=query, num_inference_steps=4, guidance_scale=0).images[0]
    image.save("output.jpg")
    return "output.jpg"

# Document Question Answering Tool
def doc_question_answering(query, file_path):
    with open(file_path, 'r') as f:
        file_content = f.read()
    # Split the document into smaller chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.create_documents([file_content])
    # Create embeddings (OpenAIEmbeddings requires OPENAI_API_KEY; swap in another
    # embeddings model here if preferred)
    embeddings = OpenAIEmbeddings()
    # Set up the Chroma database for document retrieval
    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")
    # RetrievalQA needs a LangChain-compatible LLM, so we assume the ChatGroq wrapper here
    # (it reads GROQ_API_KEY from the environment) rather than calling the raw groq client
    llm = ChatGroq(model=MODEL)
    # Set up the RetrievalQA chain and run the QA process over the retrieved chunks
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=db.as_retriever())
    return qa.run(query)
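
# Illustrative usage (hypothetical path): doc_question_answering("Summarize this document", "notes.txt")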

# Function to handle different input types and choose the right tool
def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
    if audio:
        if isinstance(audio, str):
            audio = open(audio, "rb")
        transcription = client.audio.transcriptions.create(
            file=(audio.name, audio.read()),
            model="whisper-large-v3"
        )
        user_prompt = transcription.text

    tools = [
        Tool(
            name="Numpy Code Calculator",
            func=numpy_code_calculator,
            description="Useful for when you need to perform mathematical calculations using NumPy. Provide the calculation you want to perform.",
        ),
        Tool(
            name="Web Search",
            func=web_search,
            description="Useful for when you need to find information from the real world.",
        ),
        Tool(
            name="Image Generation",
            func=image_generation,
            description="Useful for when you need to generate an image based on a description.",
        ),
    ]

    if doc:
        tools.append(
            Tool(
                name="Document Question Answering",
                # gr.File(type="filepath") passes a path string, so use it directly
                func=lambda query: doc_question_answering(query, doc),
                description="Useful for when you need to answer questions about the uploaded document.",
            )
        )

    # Build the zero-shot agent prompt; {agent_scratchpad} is required so the agent can
    # record its intermediate tool calls
    prefix = """You are an AI assistant. You have access to the following tools:"""
    suffix = """Begin!

{chat_history}
Human: {input}
AI: I will do my best to assist you. Let me think about this step-by-step:
{agent_scratchpad}"""
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "chat_history", "agent_scratchpad"]
    )
    # The agent needs a LangChain-compatible LLM; the raw groq SDK client is not one,
    # so we assume the ChatGroq wrapper here (reads GROQ_API_KEY from the environment)
    llm = ChatGroq(model=MODEL)
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
    # Memory supplies the {chat_history} variable used in the prompt
    memory = ConversationBufferMemory(memory_key="chat_history")
    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, verbose=True)

    if image:
        image = Image.open(image).convert('RGB')
        messages = [{"role": "user", "content": [image, user_prompt]}]
        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
        return response

    if websearch:
        response = agent_executor.run(f"{user_prompt} Use the Web Search tool if necessary.")
    else:
        response = agent_executor.run(user_prompt)
    return response

# Gradio UI Setup
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")
                output_label = gr.Label(label="Output")
                audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, doc_input, voice_only_mode, websearch_mode],
            outputs=[output_label, audio_output]
        )

        # Gradio expects one update per output component, so return a list of updates
        voice_only_mode.change(
            lambda x: [gr.update(visible=not x)] * 5,
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, doc_input, websearch_mode, submit]
        )
        voice_only_mode.change(
            lambda x: gr.update(visible=x),
            inputs=voice_only_mode,
            outputs=[audio_input]
        )
    return demo

# Main interface function
# Assumed for ZeroGPU ("Running on Zero"): the imported `spaces` package provides the
# @spaces.GPU decorator, which allocates a GPU for the duration of the call so the
# models can be moved to CUDA below
@spaces.GPU
def main_interface(user_prompt, image=None, audio=None, doc=None, voice_only=False, websearch=False):
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    unet.to("cuda")
    image_pipe.to("cuda")
    response = handle_input(user_prompt, image=image, audio=audio, doc=doc, websearch=websearch)
    if voice_only:
        audio_output = play_voice_output(response)
        return "Response generated.", audio_output
    else:
        return response, None

# Launch the UI
demo = create_ui()
demo.launch()