Spaces:
Runtime error
Runtime error
test party filter
Browse files- Home.py +2 -2
- src/chatbot.py +29 -20
Home.py
CHANGED
@@ -39,10 +39,10 @@ with gr.Blocks() as App:
|
|
39 |
file = gr.File(file_types=[".xlsx", ".csv", ".json"], visible=False)
|
40 |
|
41 |
#Keyword Search on click
|
42 |
-
def search(keyword, n, party): #ToDo: Include party
|
43 |
return {
|
44 |
output_col: gr.Column(visible=True),
|
45 |
-
results_df: keyword_search(query=keyword, n=n),
|
46 |
}
|
47 |
|
48 |
search_btn.click(
|
|
|
39 |
file = gr.File(file_types=[".xlsx", ".csv", ".json"], visible=False)
|
40 |
|
41 |
#Keyword Search on click
|
42 |
+
def search(keyword, n, party): #ToDo: Include party and timedate
|
43 |
return {
|
44 |
output_col: gr.Column(visible=True),
|
45 |
+
results_df: keyword_search(query=keyword, n=n, party_filter=party),
|
46 |
}
|
47 |
|
48 |
search_btn.click(
|
src/chatbot.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
-
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
2 |
from langchain_core.prompts import ChatPromptTemplate
|
3 |
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
4 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
5 |
|
6 |
from src.vectordatabase import RAG, get_vectorstore
|
7 |
import pandas as pd
|
8 |
-
import
|
9 |
-
#from dotenv import load_dotenv, find_dotenv
|
10 |
|
11 |
#Load environmental variables from .env-file
|
12 |
#load_dotenv(find_dotenv())
|
@@ -63,22 +61,33 @@ def chatbot(message, history, db=db, llm=llm, prompt=prompt2):
|
|
63 |
return response
|
64 |
|
65 |
# Retrieve speech contents based on keywords
|
66 |
-
def keyword_search(query,n=10, db=db, embeddings=embeddings):
|
67 |
query_embedding = embeddings.embed_query(query)
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
return df_res
|
|
|
|
|
1 |
from langchain_core.prompts import ChatPromptTemplate
|
2 |
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
3 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
|
5 |
from src.vectordatabase import RAG, get_vectorstore
|
6 |
import pandas as pd
|
7 |
+
from dotenv import load_dotenv, find_dotenv
|
|
|
8 |
|
9 |
#Load environmental variables from .env-file
|
10 |
#load_dotenv(find_dotenv())
|
|
|
61 |
return response
|
62 |
|
63 |
# Retrieve speech contents based on keywords
|
64 |
+
def keyword_search(query,n=10, db=db, embeddings=embeddings, method='ss', party_filter = ''):
|
65 |
query_embedding = embeddings.embed_query(query)
|
66 |
+
if method == 'mmr':
|
67 |
+
df_res = pd.DataFrame(columns=['Speech Content','Date', 'Party', 'Relevance']) # Add Date/Party/Politician
|
68 |
+
results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k = n, fetch_k = n + 10) #Add filter
|
69 |
+
for doc in results:
|
70 |
+
speech_content = doc[0].page_content
|
71 |
+
speech_date = doc[0].metadata["date"]
|
72 |
+
party = doc[0].metadata["party"]
|
73 |
+
score = round(doc[1], ndigits=2) # Relevance based on relevance search
|
74 |
+
df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
|
75 |
+
'Date': [speech_date],
|
76 |
+
'Party': [party],
|
77 |
+
'Relevance': [score]})], ignore_index=True)
|
78 |
+
df_res.sort_values('Relevance', inplace=True, ascending=True)
|
79 |
+
else:
|
80 |
+
df_res = pd.DataFrame(columns=['Speech Content','Date', 'Party']) # Add Date/Party/Politician #Add filter
|
81 |
+
results = db.similarity_search_by_vector(query_embedding, k = n, filter={"party": party_filter})
|
82 |
+
for doc in results:
|
83 |
+
party = doc.metadata["party"]
|
84 |
+
#Filter by party input
|
85 |
+
#if party != party_filter or party_filter == '':
|
86 |
+
# continue
|
87 |
+
speech_content = doc.page_content
|
88 |
+
speech_date = doc.metadata["date"]
|
89 |
+
|
90 |
+
df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
|
91 |
+
'Date': [speech_date],
|
92 |
+
'Party': [party]})], ignore_index=True)
|
93 |
return df_res
|