from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

import base64
import os
import time

import gradio as gr

# Load API keys from a local .env file into the environment.
load_dotenv()

openai_api = os.getenv("open_api_key")
pinecone_api = os.getenv("pinecone_api_key")

index = "wooribank"

pc = Pinecone(api_key=pinecone_api)

embeddings = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    openai_api_key=openai_api
)

llm = ChatOpenAI(
    openai_api_key=openai_api,
    model_name='gpt-4o',
    temperature=0.0
)

vectorstore = PineconeVectorStore(
    index_name=index,
    embedding=embeddings,
    pinecone_api_key=pinecone_api,
    namespace="finance"
)
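
# Note: the embedding model must match the one used to populate the "wooribank"
# index; otherwise the similarity scores used below are not meaningful.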


def search_documents(query, k=3):
    """Retrieve up to k chunks from Pinecone and keep only sufficiently similar ones."""
    similarity_threshold = 0.8
    results = vectorstore.similarity_search_with_score(query, k=k)

    # Keep only documents whose similarity score meets the threshold.
    filtered_results = [doc for doc, score in results if score >= similarity_threshold]

    # None signals the prompt that no relevant context was found.
    if not filtered_results:
        return None

    return "\n\n".join([doc.page_content for doc in filtered_results])
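
# Illustrative usage (not executed here): search_documents("적금 금리") would return
# the joined page_content of chunks scoring >= 0.8, or None if nothing qualifies.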

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "You are a helpful Korean AI assistant that answers financial questions based on the provided information.\n\n"
        "### Context ###\n"
        "{context}\n\n"
        "### Question ###\n"
        "{question}\n\n"
        "### Answer ###\n"
        "Based on the provided context, provide a relevant and helpful answer in Korean. "
        "Your response should be concise, with no more than 5 sentences. "
        "If the context provides relevant information, answer the question clearly without adding unnecessary details. "
        "If and only if the context provides no relevant information at all, then respond with: '대답하기에 충분한 정보를 알지 못합니다. 다른 질문을 해주시겠어요?'. "
        "If the context contains Woori Bank product details, and the query is not directly about Woori Bank, mention the product naturally at the end as an example."
    )
)

qa_chain = (
    {
        "context": RunnableLambda(search_documents),
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
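
# The chain maps the incoming question to {"context": retrieved text, "question": question},
# fills the prompt template, calls GPT-4o, and parses the reply to a plain string,
# so qa_chain.invoke("...") returns a str.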


def image_to_base64(file_path):
    """Return the file's contents base64-encoded, for inlining into the CSS below."""
    with open(file_path, "rb") as file:
        return base64.b64encode(file.read()).decode("utf-8")


loc = os.path.join(os.path.dirname(__file__), "temp.png")
image_base64 = image_to_base64(loc)
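
# The PNG is inlined as a base64 data URI, so the CSS background image needs no
# separately served static file.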

css = f"""
.gradio-container {{
    background-image: url('data:image/png;base64,{image_base64}');
    background-size: 700px 300px;
    background-position: left bottom;
    background-repeat: no-repeat;
}}

#custom-button {{
    background-color: #0067ac;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    cursor: pointer;
}}

#custom-button:hover {{
    background-color: #2482C5;
}}

#chatbot-box {{
    max-height: 700px;
    overflow-y: auto;
    border: 1px solid #ccc;
    border-radius: 10px;
    padding: 10px;
    background-color: #f9f9f9;
}}

#chatbot-box .message.user {{
    background-color: #2482C5;
    color: white;
    border-radius: 10px;
    padding: 10px;
    text-align: right;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
    animation: fadeIn 0.5s ease-in-out;
}}

#chatbot-box .message.bot {{
    background-color: #a0a0a0;
    color: white;
    border-radius: 10px;
    padding: 10px;
    text-align: left;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
    animation: fadeIn 0.5s ease-in-out;
}}

@keyframes fadeIn {{
    from {{
        opacity: 0;
    }}
    to {{
        opacity: 1;
    }}
}}
"""
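
# The #custom-button and #chatbot-box selectors above target elem_id values assigned
# to the Gradio components defined further down.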

chat_history = []


def chatbot_interface(user_input):
    # Show a placeholder immediately, then stream the real answer in its place.
    chat_history.append((user_input, "답변 생성 중..."))
    yield chat_history

    answer = qa_chain.invoke(user_input)

    # Reveal the answer one character at a time to simulate typing.
    temp = ""
    for char in answer:
        temp += char
        chat_history[-1] = (user_input, temp)
        yield chat_history
        time.sleep(0.05)
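
# Because chatbot_interface is a generator, Gradio re-renders the Chatbot on every
# yield, which produces the incremental "typing" effect above.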

right_aligned_text = """
<div align="right">
<h3>This project was developed as part of the Woori Bank internship program.</h3>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Financial Helper Chatbot for Students")

    with gr.Row():
        with gr.Column():
            gr.Markdown("")
            gr.Markdown("")
            gr.Markdown("")
            gr.Markdown("# Your Nearest Financial Helper")
            gr.Markdown("## by Woori Bank")

        with gr.Column():
            gr.Markdown("")
            gr.Markdown("")
            chatbot = gr.Chatbot(label="대화 내역", elem_id="chatbot-box")
            user_input = gr.Textbox(label="", placeholder="메시지를 입력하세요...", lines=1, show_label=False)
            submit_button = gr.Button("전송", elem_id="custom-button")
            gr.Markdown("")

    gr.Markdown(right_aligned_text)

    submit_button.click(
        chatbot_interface,
        inputs=user_input,
        outputs=chatbot,
        queue=True
    )

    user_input.submit(
        chatbot_interface,
        inputs=user_input,
        outputs=chatbot,
        queue=True
    )
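
    # Both the send button and pressing Enter in the textbox trigger the same
    # streaming handler; queue=True routes these generator events through Gradio's queue.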


if __name__ == "__main__":
    demo.launch(debug=True)