|
import logging |
|
import os |
|
from time import asctime |
|
|
|
import gradio as gr |
|
from llama_index.core import Document, VectorStoreIndex |
|
|
|
from generate_response import generate_chat_response_with_history, set_llm, is_search_query, google_question, \ |
|
generate_chat_response_with_history_rag_return_response |
|
from read_write_index import read_write_index |
|
from web_search import search |
|
|
|
# Relative path to a file holding the GPT API key.
API_KEY_PATH = "../keys/gpt_api_key.txt"

# Shared application logger, looked up by name.
logger = logging.getLogger("agent_logger")

# Hard-coded, numbered source lists, one per supported topic.
# NOTE: the mushroom list separates entries with " \n" (space before the
# newline) while the other lists use "\n " (space after it); both separators
# are kept exactly as they are because the text is displayed verbatim.
mush_sources = " \n".join((
    "1. https://en.wikipedia.org/wiki/Mushroom_poisoning",
    "2. https://thehomesteadtraveler.com/foraging-for-mushrooms-in-italy/",
    "3. https://funghimagazine.it/mushroom-hunting-in-italy/",
))

email_sources = "\n ".join((
    "1. https://support.microsoft.com/en-us/office/advanced-outlook-com-security-for-microsoft-365-subscribers-882d2243-eab9-4545-a58a-b36fee4a46e2",
    "2. https://support.microsoft.com/en-us/office/security-and-privacy-in-outlook-web-app-727a553e-5502-4899-b1ea-c84a9ddde2af",
    "3. https://support.microsoft.com/en-us/office/delay-or-schedule-sending-email-messages-in-outlook-026af69f-c287-490a-a72f-6c65793744ba",
    "4. https://www.paubox.com/blog/scheduling-emails-and-hipaa-compliance",
))

cake_sources = "\n ".join((
    "1. https://www.indianhealthyrecipes.com/eggless-carrot-cake/",
    "2. https://www.pccmarkets.com/taste/2013-03/egg_substitutes/",
    "3. https://www.healthdirect.gov.au/nut-allergies",
))

art_sources = "\n ".join((
    "1. https://en.wikipedia.org/wiki/Post-Impressionism",
    "2. https://www.metmuseum.org/toah/hd/poim/hd_poim.htm",
    "3. https://www.britannica.com/art/Post-Impressionism",
    "4. https://www.theartstory.org/movement/post-impressionism/",
))
|
|
|
|
|
def _offline_sources(message):
    """Return the hard-coded source list whose keywords appear in *message*.

    Matching is a case-insensitive substring test and the topic groups are
    tried in a fixed order (mushrooms, email, cake, art); the first group
    with any hit wins. A fixed placeholder sentence is returned when no
    keyword matches.
    """
    lowered = message.lower()
    # Built at call time so the module-level source constants are only
    # looked up when the helper actually runs.
    keyword_sources = (
        (("mushroom", "poison", "italy"), mush_sources),
        (("email", "data", "gdpr"), email_sources),
        (("cake", "egg", "nut"), cake_sources),
        (("art", "post-impressionism", "postimpressionism"), art_sources),
    )
    for keywords, topic_sources in keyword_sources:
        if any(keyword in lowered for keyword in keywords):
            return topic_sources
    return "No sources available for this response."


def google_search_chat(message, history):
    """Stream a chat reply, backed by web search results when the query calls for it.

    This is a generator: each yielded value is the full answer text
    accumulated so far, and the last yielded value is the complete answer
    followed by a "Sources used" footer.

    Flow:
      1. ``google_question`` derives a search question from the message and
         history (presumably a reformulated query -- confirm in
         ``generate_response``).
      2. If ``is_search_query`` rejects it, the reply is delegated entirely to
         ``generate_chat_response_with_history``.
      3. Otherwise the search results' text is indexed in a fresh
         ``VectorStoreIndex`` and the numbered result URLs become the sources.
      4. If the search produced no usable text, the index stored under
         ``storage_search/`` is used instead and the sources are picked from
         the hard-coded keyword lists.

    NOTE(review): the original indentation of this function was lost. The
    keyword-based source selection is assumed to belong to the offline-backup
    branch only; otherwise the URLs collected from live search results would
    always be overwritten.

    Args:
        message: The latest user message.
        history: Prior conversation turns, passed through unchanged to the
            response helpers.

    Yields:
        str: Progressively longer answer text, then the answer plus sources.
    """
    gquestion = google_question(message, history)
    if not is_search_query(gquestion):
        yield from generate_chat_response_with_history(message, history)
        return

    search_results = search(message, gquestion)
    print(f'Search results returned: {len(search_results)}')

    # Each result contributes a newline followed by its text; ''.join accepts
    # either a single string or a sequence of string fragments.
    relevant_content = "".join(
        "\n" + ''.join(result['text']) for result in search_results
    )
    sources = "".join(
        f'\n {position}. ' + result['url']
        for position, result in enumerate(search_results, start=1)
    )

    # strip() matters: every result adds a leading newline, so results whose
    # text is empty would otherwise pass a plain != "" check and a
    # whitespace-only document would be vectorized.
    if relevant_content.strip():
        search_index = VectorStoreIndex.from_documents([Document(text=relevant_content)])
        print('Search results vectorized...')
    else:
        print('Assistant Response: Sorry, no search results found, trying offline backup...')
        search_index = read_write_index(path='storage_search/')
        sources = _offline_sources(message)

    response = generate_chat_response_with_history_rag_return_response(search_index, message, history)

    # Re-yield the cumulative text for every streamed chunk.
    string_output = ""
    for text in response.response_gen:
        string_output += text
        yield string_output
    yield string_output + f'\n\n --- \n **Sources used:** \n {sources}'

    print(f'Assistant Response: {string_output}')
|
|
|
|
|
if __name__ == '__main__':
    # Set the root logger threshold to INFO.
    logging.getLogger().setLevel(logging.INFO)

    # One log file per run: the name embeds asctime() lowercased, with spaces
    # and colons removed. The file is opened in append mode.
    run_stamp = asctime().replace(" ", "").lower().replace(":", "")
    filehandler = logging.FileHandler(f'agent_log_{run_stamp}.log', 'a')
    filehandler.setFormatter(
        logging.Formatter('%(asctime)-15s::%(levelname)s::%(filename)s::%(funcName)s::%(lineno)d::%(message)s')
    )

    logger = logging.getLogger("agent_logger")
    # Detach any file handlers already on the logger before adding the new
    # one; the list is built first so handlers are not removed mid-iteration.
    stale_file_handlers = [
        handler for handler in logger.handlers if isinstance(handler, logging.FileHandler)
    ]
    for handler in stale_file_handlers:
        logger.removeHandler(handler)
    logger.addHandler(filehandler)
    logger.setLevel(logging.INFO)

    # The API key is read from the environment; os.getenv returns None when
    # the variable is unset.
    api_key = os.getenv('gpt_api_key')
    set_llm(key=api_key, model="gpt-4-0125-preview", temperature=0)

    print("Launching Gradio ChatInterface for searchbot_sourced...")

    # Chat UI driven by the streaming generator; the retry, undo and clear
    # buttons are disabled.
    demo = gr.ChatInterface(
        fn=google_search_chat,
        title="Search Assistant",
        retry_btn=None,
        undo_btn=None,
        clear_btn=None,
        theme="soft",
    )
    demo.launch()
|
|
|
|