PG-Research-ai / app.py
Dr-Newtons's picture
Update app.py
95f6225 verified
import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict
from transformers import AutoTokenizer
from tools.tools import toolsInfo
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv
load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]


class _ModelConfigRequired(TypedDict):
    client: OpenAI | Groq | anthropic.Anthropic
    model: str
    max_context: int
    # The tokenizer instance returned by AutoTokenizer.from_pretrained(...)
    tokenizer: AutoTokenizer


class ModelConfig(_ModelConfigRequired, total=False):
    # Optional model used when the main model emits tool calls as raw text
    tools_model: str


def __buildModelConfig(modelType: str) -> ModelConfig:
    """Build the client/model/tokenizer settings for ``modelType`` only.

    Only the selected provider's client is constructed. The OpenAI and Groq SDK
    clients refuse to construct without an API key, so building every provider
    eagerly would require keys for providers the app never uses.

    NOTE(review): Streamlit re-executes this script on every interaction, which
    reloads the tokenizer each time; consider wrapping this in ``st.cache_resource``.

    Raises:
        ValueError: if ``modelType`` is not one of GPT4, CLAUDE, LLAMA.
    """
    if modelType == "GPT4":
        return {
            "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
            "model": "gpt-4o-mini",
            "max_context": 128000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
        }
    if modelType == "CLAUDE":
        return {
            "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
            "model": "claude-3-5-sonnet-20240620",
            "max_context": 128000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
        }
    if modelType == "LLAMA":
        return {
            "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
            "model": "llama-3.1-70b-versatile",
            # "model": "llama-3.2-90b-text-preview",
            "tools_model": "llama3-groq-70b-8192-tool-use-preview",
            "max_context": 12800,  # intentionally reduced to 1/10th
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
        }
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; expected one of GPT4, CLAUDE, LLAMA"
    )


# MODEL_TYPE is case/whitespace-insensitive; unset or blank selects LLAMA.
modelType: ModelType = (os.environ.get("MODEL_TYPE") or "").strip().upper() or "LLAMA"
# Holds only the selected provider's entry.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {modelType: __buildModelConfig(modelType)}
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Fall back to the main model when the provider has no dedicated tool-calling model.
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return how many tokens the active model's tokenizer yields for ``str(text)``.

    Non-string inputs (such as a list of message dicts) are measured through their
    ``str()`` form, so the result is an estimate of request size. Special tokens are
    not added.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Browser-tab title and icon for the app.
st.set_page_config(
    page_icon=C.AI_ICON,
    page_title="Dr Newtons PG research Ai",
)
# Link the web-app manifest.
# NOTE(review): st.markdown places this tag in the page body; browsers may ignore
# rel="manifest" outside <head>. Confirm it actually takes effect.
st.markdown('<link rel="manifest" href="manifest.json">', unsafe_allow_html=True)
def __isInvalidResponse(response: str) -> bool:
    """Heuristically decide whether an LLM reply is malformed and should be regenerated.

    Each heuristic logs its reason via ``U.pprint`` and returns ``True``:
      * more than 3 line starts with a lowercase letter (URLs excluded), unless the
        reply contains a code fence;
      * more than one run of a word repeated 3+ times in a row;
      * more than 20 blank-line paragraph breaks;
      * the ``C.EXCEPTION_KEYWORD`` marker that ``predict`` emits on API errors;
      * a ``questions``/``action`` JSON payload without ``C.JSON_SEPARATOR``;
      * a reply that starts with ``C.JSON_SEPARATOR`` (options but no text).

    Returns:
        ``True`` if any heuristic fires, otherwise ``False``.
    """
    # Inside code blocks lowercase line starts are normal, hence the ``` exemption.
    if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True
    if response.count("\n\n") > 20:
        U.pprint("lots of paragraphs")
        return True
    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        return True
    hasJsonPayload = '{\n "questions"' in response or '{\n "action"' in response
    if hasJsonPayload and C.JSON_SEPARATOR not in response:
        U.pprint("JSON response without json separator")
        return True
    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count how many entries of ``keywords`` occur as substrings of ``text``.

    Duplicate keywords are counted once per occurrence in ``keywords``.
    """
    return sum(keyword in text for keyword in keywords)
def __getMessages():
    """Return ``st.session_state.messages`` trimmed in place to fit ``MAX_CONTEXT``.

    The estimated request size is the system prompt plus the serialized history plus a
    100-token margin. While it exceeds ``MAX_CONTEXT`` the oldest message is dropped.
    After each drop, further leading messages are discarded until the history starts
    with a plain user prompt (``role == "user"`` with string content). This keeps
    trimming from leaving a tool result (a ``role == "tool"`` message or a Claude
    ``tool_result`` list) without the tool call that produced it, which the chat
    APIs reject, and keeps the history starting with a user turn.

    Returns:
        The same list object as ``st.session_state.messages``; it may be empty if
        even the newest message alone does not fit.
    """
    messages = st.session_state.messages

    def getContextSize():
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    def startsWithUserPrompt():
        head = messages[0]
        return head.get("role") == "user" and isinstance(head.get("content"), str)

    # Guarding on `messages` avoids IndexError from pop(0) on an emptied history.
    while messages and getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)
        while messages and not startsWithUserPrompt():
            messages.pop(0)
    return messages
def __logLlmRequest(messagesFormatted: list, model: str):
    """Log the estimated token size of an outgoing request together with the model name."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {model}")
# Tool schemas sent with every chat-completions request. Each entry exposes
# function.name / description / parameters, which __getClaudeTools re-maps for Anthropic.
tools = [toolsInfo["getGoogleSearchResults"]["schema"]]
def __showToolResponse(toolResponseDisplay: dict):
    """Render a tool's display payload as a small icon beside a status message.

    Args:
        toolResponseDisplay: Mapping with optional ``"text"`` (markdown message) and
            ``"icon"`` (image source; ``C.TOOL_ICON`` when missing or empty).
            A missing or empty ``"text"`` renders the icon only instead of raising.
    """
    msg = toolResponseDisplay.get("text") or ""
    icon = toolResponseDisplay.get("icon") or C.TOOL_ICON
    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(icon, width=30)
    with col2:
        if msg:
            # Plain status text is shown as inline code; text that already contains
            # backticks is rendered as-is to avoid broken code spans.
            st.markdown(msg if "`" in msg else f"`{msg}`")
def __addToolCallToMsgs(toolCall: dict):
    """Append the assistant turn that requested a tool call to the LLM history.

    For Claude, ``toolCall`` is already a complete assistant message and is appended
    unchanged. Otherwise ``toolCall`` is an OpenAI/Groq tool-call object (with ``.id``,
    ``.type`` and ``.function.name`` / ``.function.arguments``). It is wrapped in an
    assistant message carrying a single ``tool_calls`` entry.
    """
    if isClaudeModel:
        message = toolCall
    else:
        function = toolCall.function
        message = {
            "role": "assistant",
            "tool_calls": [{
                "id": toolCall.id,
                "function": {"name": function.name, "arguments": function.arguments},
                "type": toolCall.type,
            }],
        }
    st.session_state.messages.append(message)
def __runToolCall(functionName: str, functionArgsStr: str):
    """Invoke the registered tool ``functionName`` with JSON-encoded arguments.

    Returns:
        ``(response, display)`` taken from the tool result's ``"response"`` and
        ``"display"`` keys. An unknown tool name, undecodable JSON or non-object
        arguments yield an ``"Error: ..."`` response and no display instead of
        raising. Exceptions raised by the tool itself still propagate.
    """
    if functionName not in toolsInfo:
        return f"Error: unknown tool '{functionName}'", None
    try:
        # An empty (or missing) argument string is treated as "no arguments".
        functionArgs = json.loads(functionArgsStr) if functionArgsStr else {}
    except json.JSONDecodeError as e:
        return f"Error: invalid JSON arguments for tool '{functionName}': {e}", None
    if not isinstance(functionArgs, dict):
        return f"Error: arguments for tool '{functionName}' must be a JSON object", None
    functionResult = toolsInfo[functionName]["func"](**functionArgs)
    return functionResult.get("response"), functionResult.get("display")


def __processToolCalls(toolCalls):
    """Execute OpenAI/Groq-style tool calls and record each call/result pair.

    For every call, the tool is run through ``__runToolCall``. Any ``display`` payload is
    rendered and stored in ``st.session_state.toolResponseDisplay``. Then the assistant
    tool-call message and a ``role="tool"`` result message are appended to
    ``st.session_state.messages``. A bad tool name or bad arguments produce an error
    string as the tool result, so every call stays paired with a result and the model
    can recover on the next request.
    """
    for toolCall in toolCalls:
        functionName = toolCall.function.name
        functionArgsStr = toolCall.function.arguments
        U.pprint(f"{functionName=} | {functionArgsStr=}")
        functionResponse, responseDisplay = __runToolCall(functionName, functionArgsStr)
        U.pprint(f"{functionResponse=}")
        if responseDisplay:
            __showToolResponse(responseDisplay)
            st.session_state.toolResponseDisplay = responseDisplay
        __addToolCallToMsgs(toolCall)
        st.session_state.messages.append({
            "role": "tool",
            "tool_call_id": toolCall.id,
            "name": functionName,
            "content": functionResponse,
        })
def __processClaudeToolCalls(toolResponse):
    """Execute every ``tool_use`` block of a Claude reply and record the exchange.

    Args:
        toolResponse: The ``content`` list of an Anthropic message. Text blocks and
            ``tool_use`` blocks may appear in any order.

    Each ``tool_use`` block's tool is called with its ``input``. Any ``display`` payload
    is rendered and stored in ``st.session_state.toolResponseDisplay``. The whole content
    list is then appended as the assistant turn, followed by ONE user turn holding a
    ``tool_result`` for every ``tool_use`` block, because Anthropic requires each
    ``tool_use`` to be answered in the next user message. Unknown tool names are
    answered with an error result (``is_error``). If there is no ``tool_use`` block,
    the history is left untouched.
    """
    toolResults = []
    for toolCall in toolResponse:
        if getattr(toolCall, "type", None) != "tool_use":
            continue  # text (or other) blocks carry nothing to execute
        functionName = toolCall.name
        toolResult = {"type": "tool_result", "tool_use_id": toolCall.id}
        if functionName in toolsInfo:
            functionResult = toolsInfo[functionName]["func"](**toolCall.input)
            functionResponse = functionResult.get("response")
            responseDisplay = functionResult.get("display")
        else:
            functionResponse = f"Error: unknown tool '{functionName}'"
            responseDisplay = None
            toolResult["is_error"] = True
        U.pprint(f"{functionResponse=}")
        if responseDisplay:
            __showToolResponse(responseDisplay)
            st.session_state.toolResponseDisplay = responseDisplay
        toolResult["content"] = functionResponse
        toolResults.append(toolResult)

    if not toolResults:
        U.pprint("No tool_use blocks found in Claude response")
        return
    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse,
    })
    st.session_state.messages.append({
        "role": "user",
        "content": toolResults,
    })
def __dedupeToolCalls(toolCalls: list):
    """Collapse tool calls that target the same function, keeping the last call per name.

    The result is ordered by each function name's first appearance. The function name
    comes from ``.name`` for Claude blocks and ``.function.name`` otherwise. A log line
    is emitted whenever something was dropped.
    """
    def nameOf(call):
        return call.name if isClaudeModel else call.function.name

    # Later calls overwrite earlier ones but keep the key's original position.
    latestByName = {nameOf(call): call for call in toolCalls}
    dedupedToolCalls = list(latestByName.values())
    if len(dedupedToolCalls) != len(toolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"{toolCalls=} -> {dedupedToolCalls=}")
    return dedupedToolCalls
def __getClaudeTools():
    """Translate the module-level ``tools`` schemas into Anthropic's tool format.

    Each entry's ``function.name`` / ``function.description`` / ``function.parameters``
    become ``name`` / ``description`` / ``input_schema``.
    """
    return [
        {
            "name": spec["function"]["name"],
            "description": spec["function"]["description"],
            "input_schema": spec["function"]["parameters"],
        }
        for spec in tools
    ]
def __removeFunctionCall(response: str):
    """Strip tool calls that the model wrote inline as text from ``response``.

    Removes every ``<function=NAME>{...}</function>`` span. It handles any tool name,
    JSON arguments with nested braces or spanning several lines, and either the
    ``</function>`` or the bare ``<function>`` closing tag.
    """
    pattern = r'<function=[^>\s]+>\{.*?\}</?function>'
    return re.sub(pattern, '', response, flags=re.DOTALL)
def predict(model: str | None = None, attempts=0):
    """Get the assistant's next reply from the configured LLM, running tool calls.

    The context-trimmed history from ``__getMessages`` is sent with ``C.SYSTEM_MSG`` and
    the tool schemas. If the model requests tools, they are executed (their results
    are appended to ``st.session_state.messages``) and the default ``MODEL`` is
    queried again.

    On the OpenAI/Groq path, a reply containing a raw ``<function=`` tool call as text
    is retried up to 3 times. After that the inline call text is stripped. If some
    still remains, ``TOOLS_MODEL`` is tried, but only when it differs from ``model``,
    so this cannot recurse endlessly.

    Args:
        model: Model name to query; defaults to ``MODEL``.
        attempts: Number of malformed-tool-call retries already made.

    Returns:
        The reply text; ``""`` when the model produced no text (including when tool
        execution failed); or ``f"{C.EXCEPTION_KEYWORD} | {error}"`` when the request
        fails. Never ``None``.
    """
    model = model or MODEL
    messagesFormatted = []
    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)
            responseMessage = client.messages.create(
                model=model,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.5,
                max_tokens=4000,
                tools=__getClaudeTools()
            )
            responseMessageContent = responseMessage.content
            # Claude may return text and tool_use blocks in any order, or no text at all.
            responseContent = "".join(
                block.text for block in responseMessageContent if block.type == "text"
            )
            toolCalls = [
                block for block in responseMessageContent if block.type == "tool_use"
            ]
        else:
            messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)
            response = client.chat.completions.create(
                model=model,
                messages=messagesFormatted,
                temperature=0.6,
                max_tokens=4000,
                stream=False,
                tools=tools
            )
            responseMessage = response.choices[0].message
            responseContent = responseMessage.content
            # The model sometimes writes the tool call as text instead of tool_calls.
            if responseContent and '<function=' in responseContent:
                U.pprint(f"Wrong toolCall response: {responseContent}")
                if attempts < 3:
                    U.pprint(f"Retrying...{attempts + 1}/3")
                    time.sleep(0.2)
                    return predict(model, attempts + 1)
                responseContent = __removeFunctionCall(responseContent)
                if "<function=" in responseContent and model != TOOLS_MODEL:
                    U.pprint("Switching to TOOLS_MODEL")
                    return predict(TOOLS_MODEL)
            toolCalls = responseMessage.tool_calls

        U.pprint(f"{toolCalls=}")
        if not toolCalls:
            return responseContent or ""

        toolCalls = __dedupeToolCalls(toolCalls)
        U.pprint("Deduping done!")
        try:
            if isClaudeModel:
                __processClaudeToolCalls(responseMessage.content)
            else:
                __processToolCalls(toolCalls)
        except Exception as e:
            U.pprint(e)
            # Fall back to any text the model produced rather than returning None.
            return responseContent or ""
        # Ask again so the model can answer using the tool results.
        return predict()
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        return f"{C.EXCEPTION_KEYWORD} | {e}"
def __generateImage(prompt: str):
    """Generate a 1024x768 image for ``prompt`` with the FLUX.1-schnell Gradio Space.

    Calls the Space's ``/infer`` endpoint with a randomized seed and 4 inference
    steps, logs the raw result and returns it unchanged. The result is presumably
    an (image path, seed) pair; verify this against the Space's API.
    """
    spaceClient = Client("black-forest-labs/FLUX.1-schnell")
    imageResult = spaceClient.predict(
        api_name="/infer",
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=768,
        num_inference_steps=4,
    )
    U.pprint(f"imageResult={imageResult}")
    return imageResult
def __resetButtonState():
    """Clear the pending option-button selection so it is not submitted as a prompt again."""
    st.session_state["buttonValue"] = ""
if "ipAddress" not in st.session_state:
st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
if "chatHistory" not in st.session_state:
st.session_state.chatHistory = []
if "messages" not in st.session_state:
st.session_state.messages = []
if "buttonValue" not in st.session_state:
__resetButtonState()
st.session_state.toolResponseDisplay = {}
# Two blank log lines (presumably to separate reruns in the server log).
for _ in range(2):
    U.pprint("\n")
U.applyCommonStyles()
st.title("Dr Newtons PG research Ai")

# Streamlit re-executes this script on every interaction, so the visible
# conversation is redrawn from chatHistory each time.
for pastTurn in st.session_state.chatHistory:
    turnRole = pastTurn["role"]
    turnText = pastTurn["content"]
    turnImage = pastTurn.get("image")
    turnToolDisplay = pastTurn.get("toolResponseDisplay")
    turnAvatar = C.AI_ICON if turnRole == "assistant" else C.USER_ICON
    with st.chat_message(turnRole, avatar=turnAvatar):
        st.markdown(turnText)
        if turnToolDisplay:
            __showToolResponse(turnToolDisplay)
        if turnImage:
            st.image(turnImage)
MAX_RESPONSE_ATTEMPTS = 5  # predict() attempts per user turn before giving up
FALLBACK_RESPONSE = "Sorry, I couldn't generate a response right now. Please try again."


def selectButton(optionLabel):
    """Queue ``optionLabel`` as the next prompt; it is read from ``buttonValue`` on the next rerun."""
    st.session_state["buttonValue"] = optionLabel
    U.pprint(f"Selected: {optionLabel}")


def __renderOptionButtons(jsonStr: str):
    """Render follow-up option buttons from the JSON tail of an assistant reply.

    The JSON may be wrapped in a markdown code fence, optionally tagged ``json``.
    Only its ``"questions"`` list is used. Each option needs ``"label"`` (the button
    text, also submitted as the next prompt) and ``"id"`` (the widget key).
    Rendering is best-effort: any failure is logged and the buttons are skipped.
    """
    try:
        jsonObj = json.loads(jsonStr.replace("```json", "").replace("```", ""))
        # An "action" payload may also be present; it is currently not acted upon.
        for option in jsonObj.get("questions") or []:
            st.button(
                option["label"],
                key=option["id"],
                on_click=selectButton,
                args=(option["label"],),
            )
    except Exception as e:
        # Malformed JSON/options must not break the chat turn.
        U.pprint(e)


if prompt := (
    st.chat_input("Ask anything ...")
    or st.session_state["buttonValue"]
):
    __resetButtonState()

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt })
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Fetch one full reply, show its text part, and return it (None if invalid).

            predict() returns a complete string, so it is validated and rendered once
            rather than character by character.
            """
            responseContainer.image(C.TEXT_LOADER)
            response = predict() or ""
            if __isInvalidResponse(response):
                U.pprint(f"InvalidResponse={response}")
                return None
            # Show only the prose; the JSON options after the separator become buttons.
            responseContainer.markdown(response.partition(C.JSON_SEPARATOR)[0])
            return response

        response = __printAndGetResponse()
        attempts = 1
        # Bounded, so a persistently failing API cannot loop (and bill) forever.
        while not response and attempts < MAX_RESPONSE_ATTEMPTS:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()
            attempts += 1
        if not response:
            U.pprint("Giving up after repeated invalid/empty responses")
            response = FALLBACK_RESPONSE
            responseContainer.markdown(response)
        U.pprint(f"{response=}")

        # Split at the first separator only; a repeated separator no longer raises.
        rawResponse = response
        response, _, jsonStr = rawResponse.partition(C.JSON_SEPARATOR)

        imagePath = None
        # imageContainer = st.empty()
        # try:
        #     (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
        #     if imagePrompt:
        #         imgContainer = imageContainer.container()
        #         imgContainer.write(
        #             f"""
        #             <div class='blinking code'>
        #             {loaderText}
        #             </div>
        #             """,
        #             unsafe_allow_html=True
        #         )
        #         # imgContainer.markdown(f"`{loaderText}`")
        #         imgContainer.image(C.IMAGE_LOADER)
        #         (imagePath, seed) = __generateImage(imagePrompt)
        #         imageContainer.image(imagePath)
        # except Exception as e:
        #     U.pprint(e)
        #     imageContainer.empty()

        toolResponseDisplay = st.session_state.toolResponseDisplay
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
            "toolResponseDisplay": toolResponseDisplay
        })
        # The LLM history keeps the raw reply (including JSON) so the model sees its own format.
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            __renderOptionButtons(jsonStr)

# if st.button("Rerun"):
#     # __resetButtonState()
#     st.session_state.chatHistory = []
#     st.session_state.messages = []
#     st.rerun()