import gradio as gr
import torch
import os
import numpy as np
from groq import Groq
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.llms import OpenAI  # needed by use_langchain_rag below
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from PIL import Image
from decord import VideoReader, cpu
from tavily import TavilyClient
import requests
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
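# Required environment variables (assumption based on the clients used below):
# GROQ_API_KEY for the Groq client, OPENAI_API_KEY for the LangChain OpenAI
# embeddings/LLM, and TAVILY_API_KEY for the Tavily web search client.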
# Initialize models
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
MODEL = 'llama3-groq-70b-8192-tool-use-preview'

text_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                       device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")
# Image model and pipeline setup (SDXL-Lightning 4-step UNet on the SDXL base)
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")
# Tavily client (reads the API key from the environment, like the Groq client above)
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
# Voice output function
def play_voice_output(response):
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"
# NumPy calculation function
def numpy_calculate(code: str) -> str:
    try:
        local_dict = {}
        exec(code, {"np": np}, local_dict)
        result = local_dict.get("result", "No result found")
        return str(result)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Function to use LangChain for RAG
def use_langchain_rag(file_name, file_content, query):
    # Split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.create_documents([file_content])

    # Create embeddings and store them in the vector database
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")  # Use a persistent directory

    # Create a question-answering chain
    qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=db.as_retriever())

    # Get the answer
    return qa.run(query)
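# Note: OpenAIEmbeddings and the OpenAI LLM both read OPENAI_API_KEY from the
# environment. Illustrative call (hypothetical file name and contents):
#   use_langchain_rag("notes.txt", "...document text...", "What is the main topic?")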
# Function to encode a video as a list of frames
def encode_video(video_path):
    MAX_NUM_FRAMES = 64

    def uniform_sample(lst, n):
        # Evenly sub-sample n items from lst
        gap = len(lst) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [lst[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps())  # sample roughly one frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames
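# Illustrative usage: encode_video("clip.mp4") returns at most MAX_NUM_FRAMES
# PIL images sampled at ~1 fps, suitable for passing to MiniCPM-V in a chat
# message. (handle_input below accepts a video argument but does not call this.)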
# Web search function
def web_search(query):
    answer = tavily_client.qna_search(query=query)
    return answer
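# Illustrative usage: web_search("latest stable Python release") returns
# Tavily's short QnA-style answer string for the query.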
# Function to handle different input types
def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
    # Voice input handling: the Gradio audio component passes a file path
    if audio:
        with open(audio, "rb") as f:
            transcription = client.audio.transcriptions.create(
                file=(os.path.basename(audio), f.read()),
                model="whisper-large-v3"
            )
        user_prompt = transcription.text

    # If the user uploaded an image, use the MiniCPM-V model
    if image:
        image = Image.open(image).convert('RGB')
        messages = [{"role": "user", "content": [image, user_prompt]}]
        response = text_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
        return response

    # Determine which tool to use
    if doc:
        # The Gradio file component passes a file path
        with open(doc, "r", encoding="utf-8") as f:
            file_content = f.read()
        response = use_langchain_rag(os.path.basename(doc), file_content, user_prompt)
    elif "calculate" in user_prompt.lower():
        response = numpy_calculate(user_prompt)
    elif "generate" in user_prompt.lower() and ("image" in user_prompt.lower() or "picture" in user_prompt.lower()):
        # SDXL-Lightning 4-step checkpoint: use 4 inference steps and guidance_scale=0
        response = image_pipe(prompt=user_prompt, num_inference_steps=4, guidance_scale=0).images[0]
    elif websearch:
        response = web_search(user_prompt)
    else:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": user_prompt}
            ],
            model=MODEL,
        )
        response = chat_completion.choices[0].message.content

    return response
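# Routing summary: audio is transcribed first (replacing the text prompt), an
# image routes to MiniCPM-V, a document routes to the RAG helper, "calculate"
# routes to numpy_calculate, "generate ... image/picture" routes to
# SDXL-Lightning, web search uses Tavily, and everything else falls through to
# the Groq chat model.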
def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, voice_only=False, websearch=False):
    # Make sure the models are on the GPU before handling the request
    text_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    unet.to("cuda", torch.float16)
    image_pipe.to("cuda")

    response = handle_input(user_prompt, image=image, video=video, audio=audio, doc=doc, websearch=websearch)

    if voice_only:
        audio_file = play_voice_output(response)
        return response, audio_file  # Return both text and audio outputs
    else:
        return response, None  # Return only the text output, no audio
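# Note: the .to("cuda") calls above assume a GPU is available at request time;
# on a CPU-only machine they would need to be removed or guarded with
# torch.cuda.is_available().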
# Gradio UI setup
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                video_input = gr.Video(label="Upload a video", elem_id="video-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        output_label = gr.Label(label="Output")
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, video_input, audio_input, doc_input, voice_only_mode, websearch_mode],
            outputs=[output_label, audio_output]  # Expecting a string and an audio file
        )
        # Voice-only mode UI: hide the text/image/doc controls and show only the audio input
        voice_only_mode.change(
            lambda x: [gr.update(visible=not x)] * 6,  # one update per output component
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, video_input, doc_input, websearch_mode, submit]
        )
        voice_only_mode.change(
            lambda x: gr.update(visible=x),
            inputs=voice_only_mode,
            outputs=[audio_input]
        )
    return demo

# Launch the app
demo = create_ui()
demo.launch(inline=False)
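# Optional: calling demo.queue() before demo.launch() would serialize concurrent
# requests to the single GPU, which may be desirable if several users hit the
# app at once.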