Spaces: Paused

Hemang Thakur committed · commit 25f9610 · parent d8479e9

removed neo4j and replaced with rustworkx
Browse files:
- frontend/src/Components/AiComponents/Graph.js +3 -3
- frontend/src/Components/IntialSetting.js +11 -11
- main.py +10 -12
- requirements.txt +1 -0
- src/rag/graph_rag.py +1433 -0
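
The substance of the commit is a storage swap: graph state that previously lived in an external Neo4j database is now held in an in-process rustworkx graph. A minimal sketch of the new storage pattern used by the GraphRAG class added below (the Neo4j half is a generic illustration of the pattern being removed, not code from this repo; rustworkx addresses nodes by integer index, which is why a separate name-to-index map is kept):

import rustworkx as rx

# Old pattern (removed): every read/write was a Cypher round-trip to a
# Neo4j server, e.g. driver.session().run("MERGE (n:Node {id: $id}) ...").
# New pattern: the graph lives in memory inside the Python process.
graph = rx.PyDiGraph()
node_map = {}  # node name -> rustworkx integer index

# Node payloads are arbitrary Python dicts, as in GraphRAG.add_node
node_map["QR"] = graph.add_node({"id": "QR", "query": "root query", "role": None})
node_map["SQ1"] = graph.add_node({"id": "SQ1", "query": "sub query", "role": "independent"})

# Edge payloads carry the relationship type and weight, as in GraphRAG.add_edge
graph.add_edge(node_map["QR"], node_map["SQ1"], {"type": "hierarchical", "weight": 1.0})

print(graph.num_nodes(), graph.num_edges())  # 2 1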
frontend/src/Components/AiComponents/Graph.js CHANGED

@@ -2,19 +2,19 @@ import React, { useState, useEffect } from 'react';
 import { FaTimes } from 'react-icons/fa';
 import './Graph.css';
 
-export default function Graph({ open, onClose
+export default function Graph({ open, onClose }) {
   const [graphHtml, setGraphHtml] = useState("");
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState("");
 
   useEffect(() => {
-    if (open
+    if (open) {
       setLoading(true);
       setError("");
       fetch("/action/graph", {
         method: "POST",
         headers: { "Content-Type": "application/json" },
-        body: JSON.stringify(
+        body: JSON.stringify({})
       })
         .then(res => res.json())
         .then(data => {
frontend/src/Components/IntialSetting.js CHANGED

@@ -90,9 +90,9 @@ function IntialSetting(props) {
     const modelAPIKeys = form.elements["model-api"].value;
     const braveAPIKey = form.elements["brave-api"].value;
     const proxyList = form.elements["proxy-list"].value;
-    const neo4jURL = form.elements["neo4j-url"].value;
-    const neo4jUsername = form.elements["neo4j-username"].value;
-    const neo4jPassword = form.elements["neo4j-password"].value;
+    // const neo4jURL = form.elements["neo4j-url"].value;
+    // const neo4jUsername = form.elements["neo4j-username"].value;
+    // const neo4jPassword = form.elements["neo4j-password"].value;
 
     // Validate required fields and collect missing field names
     const missingFields = [];
@@ -100,9 +100,9 @@ function IntialSetting(props) {
     if (!modelName || modelName.trim() === "") missingFields.push("Model Name");
     if (!modelAPIKeys || modelAPIKeys.trim() === "") missingFields.push("Model API Key");
     if (!braveAPIKey || braveAPIKey.trim() === "") missingFields.push("Brave Search API Key");
-    if (!neo4jURL || neo4jURL.trim() === "") missingFields.push("Neo4j URL");
-    if (!neo4jUsername || neo4jUsername.trim() === "") missingFields.push("Neo4j Username");
-    if (!neo4jPassword || neo4jPassword.trim() === "") missingFields.push("Neo4j Password");
+    // if (!neo4jURL || neo4jURL.trim() === "") missingFields.push("Neo4j URL");
+    // if (!neo4jUsername || neo4jUsername.trim() === "") missingFields.push("Neo4j Username");
+    // if (!neo4jPassword || neo4jPassword.trim() === "") missingFields.push("Neo4j Password");
 
     // If any required fields are missing, show an error notification
     if (missingFields.length > 0) {
@@ -120,9 +120,9 @@ function IntialSetting(props) {
       "Model_Name": modelName,
       "Model_API_Keys": modelAPIKeys,
       "Brave_Search_API_Key": braveAPIKey,
-      "Neo4j_URL": neo4jURL,
-      "Neo4j_Username": neo4jUsername,
-      "Neo4j_Password": neo4jPassword,
+      // "Neo4j_URL": neo4jURL,
+      // "Neo4j_Username": neo4jUsername,
+      // "Neo4j_Password": neo4jPassword,
       "Model_Temperature": modelTemperature,
       "Model_Top_P": modelTopP,
     };
@@ -282,7 +282,7 @@ function IntialSetting(props) {
           </div>
 
           {/* Neo4j Configuration */}
-          <div className="form-group">
+          {/* <div className="form-group">
             <label htmlFor="neo4j-url">Neo4j URL</label>
             <input
               type="text"
@@ -321,7 +321,7 @@ function IntialSetting(props) {
                 {showPassword ? <FaEyeSlash /> : <FaEye />}
               </IconButton>
             </div>
-          </div>
+          </div> */}
 
           {/* Model Temperature and Top-P */}
           <div className="form-group">
main.py CHANGED

@@ -41,7 +41,7 @@ def initialize_components():
 
     from src.search.search_engine import SearchEngine
     from src.query_processing.query_processor import QueryProcessor
-    from src.rag.
+    from src.rag.graph_rag import GraphRAG
     from src.evaluation.evaluator import Evaluator
     from src.reasoning.reasoner import Reasoner
     from src.crawl.crawler import CustomCrawler
@@ -53,7 +53,7 @@ def initialize_components():
     SESSION_STORE['search_engine'] = SearchEngine()
     SESSION_STORE['query_processor'] = QueryProcessor()
     SESSION_STORE['crawler'] = CustomCrawler(max_concurrent_requests=1000)
-    SESSION_STORE['graph_rag'] =
+    SESSION_STORE['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2)
     SESSION_STORE['evaluator'] = Evaluator()
     SESSION_STORE['reasoner'] = Reasoner()
     SESSION_STORE['model'] = manager.get_llm()
@@ -580,12 +580,10 @@ def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
 
 # Define the route for graph action to display the graph
 @app.post("/action/graph")
-def action_graph(
+def action_graph() -> Dict[str, Any]:
     state = SESSION_STORE
-
     try:
-
-        html_str = state['graph_rag'].display_graph(q)
+        html_str = state['graph_rag'].display_graph()
 
         return {"result": html_str}
     except Exception as e:
@@ -621,9 +619,9 @@ async def update_settings(data: Dict[str, Any]):
     multiple_api_keys = data.get("Model_API_Keys", "").strip()
     brave_api_key = data.get("Brave_Search_API_Key", "").strip()
     proxy_list = data.get("Proxy_List", "").strip()
-    neo4j_url = data.get("Neo4j_URL", "").strip()
-    neo4j_username = data.get("Neo4j_Username", "").strip()
-    neo4j_password = data.get("Neo4j_Password", "").strip()
+    # neo4j_url = data.get("Neo4j_URL", "").strip()
+    # neo4j_username = data.get("Neo4j_Username", "").strip()
+    # neo4j_password = data.get("Neo4j_Password", "").strip()
     model_temperature = str(data.get("Model_Temperature", 0.0))
     model_top_p = str(data.get("Model_Top_P", 1.0))
 
@@ -637,9 +635,9 @@ async def update_settings(data: Dict[str, Any]):
     env_updates.update(px)
 
     env_updates["BRAVE_API_KEY"] = brave_api_key
-    env_updates["NEO4J_URI"] = neo4j_url
-    env_updates["NEO4J_USER"] = neo4j_username
-    env_updates["NEO4J_PASSWORD"] = neo4j_password
+    # env_updates["NEO4J_URI"] = neo4j_url
+    # env_updates["NEO4J_USER"] = neo4j_username
+    # env_updates["NEO4J_PASSWORD"] = neo4j_password
     env_updates["MODEL_PROVIDER"] = prov_lower
     env_updates["MODEL_NAME"] = model_name
     env_updates["MODEL_TEMPERATURE"] = model_temperature
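
After this change, /action/graph takes no query parameter; the handler just renders the current graph, which is why the frontend now posts an empty JSON object. A quick smoke test against a local dev server (the host and port are assumptions, not taken from the diff):

import json
import urllib.request

# Mirror the frontend's fetch("/action/graph", { body: JSON.stringify({}) })
req = urllib.request.Request(
    "http://localhost:8000/action/graph",  # assumed local address
    data=json.dumps({}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# The route returns {"result": <html>}, where <html> is the rendered graph
print(body["result"][:80])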
requirements.txt CHANGED

@@ -18,6 +18,7 @@ langchain_xai==0.1.1
 langgraph==0.2.62
 model2vec==0.3.3
 neo4j==5.26.0
+rustworkx==0.16.0
 openai==1.59.3
 protobuf==4.23.4
 PyPDF2==3.0.1
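
Note that rustworkx==0.16.0 is added while neo4j==5.26.0 is left in place for now. Beyond storage, rustworkx also supplies graph algorithms in process: the new GraphRAG below keeps a per-node "pagerank" field and calls an update_pagerank method whose body falls outside this excerpt. A sketch of how that update can work, assuming rustworkx's built-in rx.pagerank (an assumption about the unshown method, not code from the commit):

import rustworkx as rx

g = rx.PyDiGraph()
a = g.add_node({"id": "QR", "pagerank": 0})
b = g.add_node({"id": "SQ1", "pagerank": 0})
g.add_edge(a, b, {"type": "hierarchical", "weight": 1.0})

# rx.pagerank returns a node-index -> score mapping; writing the scores back
# into the payloads mirrors the "pagerank" field GraphRAG keeps on each node.
for idx, score in rx.pagerank(g, alpha=0.85).items():
    g[idx]["pagerank"] = score  # g[idx] is the node payload dict

print({g[i]["id"]: round(g[i]["pagerank"], 3) for i in g.node_indices()})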
src/rag/graph_rag.py ADDED

@@ -0,0 +1,1433 @@
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import time
|
4 |
+
import asyncio
|
5 |
+
import torch
|
6 |
+
import uuid
|
7 |
+
import rustworkx as rx
|
8 |
+
import numpy as np
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
from typing import List, Dict, Any
|
11 |
+
from pyvis.network import Network
|
12 |
+
from src.query_processing.late_chunking.late_chunker import LateChunker
|
13 |
+
from src.query_processing.query_processor import QueryProcessor
|
14 |
+
from src.reasoning.reasoner import Reasoner
|
15 |
+
from src.utils.api_key_manager import APIKeyManager
|
16 |
+
from src.search.search_engine import SearchEngine
|
17 |
+
from src.crawl.crawler import CustomCrawler #, Crawler
|
18 |
+
from sentence_transformers import SentenceTransformer
|
19 |
+
from bert_score.scorer import BERTScorer
|
20 |
+
|
21 |
+
class GraphRAG:
|
22 |
+
def __init__(self, num_workers: int = 1):
|
23 |
+
"""Initialize graph and required components."""
|
24 |
+
# Dictionary to store multiple graphs
|
25 |
+
self.graphs = {}
|
26 |
+
self.current_graph_id = None
|
27 |
+
|
28 |
+
# Component initialization
|
29 |
+
self.num_workers = num_workers
|
30 |
+
self.search_engine = SearchEngine()
|
31 |
+
self.query_processor = QueryProcessor()
|
32 |
+
self.reasoner = Reasoner()
|
33 |
+
# self.crawler = Crawler(verbose=True)
|
34 |
+
self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
|
35 |
+
self.chunking = LateChunker()
|
36 |
+
self.llm = APIKeyManager().get_llm()
|
37 |
+
|
38 |
+
# Model initialization
|
39 |
+
self.model = SentenceTransformer(
|
40 |
+
"dunzhang/stella_en_400M_v5",
|
41 |
+
trust_remote_code=True,
|
42 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
43 |
+
)
|
44 |
+
self.scorer = BERTScorer(
|
45 |
+
model_type="roberta-base",
|
46 |
+
lang="en",
|
47 |
+
rescale_with_baseline=True,
|
48 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
)
|
50 |
+
|
51 |
+
# Counters and tracking
|
52 |
+
self.root_node_id = "QR"
|
53 |
+
self.node_counter = 0
|
54 |
+
self.sub_node_counter = 0
|
55 |
+
self.cross_connections = set()
|
56 |
+
|
57 |
+
# Thread pool
|
58 |
+
self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
|
59 |
+
|
60 |
+
# Event callback
|
61 |
+
self.on_event_callback = None
|
62 |
+
|
63 |
+
def set_on_event_callback(self, callback):
|
64 |
+
"""Register a single callback to be triggered for various event types."""
|
65 |
+
self.on_event_callback = callback
|
66 |
+
|
67 |
+
async def emit_event(self, event_type: str, data: dict):
|
68 |
+
"""Helper method to safely emit an event if a callback is registered."""
|
69 |
+
if self.on_event_callback:
|
70 |
+
if asyncio.iscoroutinefunction(self.on_event_callback):
|
71 |
+
return await self.on_event_callback(event_type, data)
|
72 |
+
else:
|
73 |
+
return self.on_event_callback(event_type, data)
|
74 |
+
|
75 |
+
def _get_current_graph_data(self):
|
76 |
+
if self.current_graph_id is None or self.current_graph_id not in self.graphs:
|
77 |
+
raise Exception("Error: No current graph selected")
|
78 |
+
|
79 |
+
return self.graphs[self.current_graph_id]
|
80 |
+
|
81 |
+
def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
|
82 |
+
"""Add a node to the current graph."""
|
83 |
+
graph_data = self._get_current_graph_data()
|
84 |
+
graph = graph_data["graph"]
|
85 |
+
node_map = graph_data["node_map"]
|
86 |
+
|
87 |
+
# Generate embedding
|
88 |
+
embedding = self.model.encode(query).tolist()
|
89 |
+
node_data = {
|
90 |
+
"id": node_id,
|
91 |
+
"query": query,
|
92 |
+
"data": data,
|
93 |
+
"role": role,
|
94 |
+
"embedding": embedding,
|
95 |
+
"pagerank": 0,
|
96 |
+
"graph_id": self.current_graph_id
|
97 |
+
}
|
98 |
+
node_index = graph.add_node(node_data)
|
99 |
+
node_map[node_id] = node_index
|
100 |
+
|
101 |
+
print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")
|
102 |
+
|
103 |
+
def _has_path(self, source_idx: int, target_idx: int) -> bool:
|
104 |
+
"""Helper method to check if there is a path from source to target in the current graph."""
|
105 |
+
graph_data = self._get_current_graph_data()
|
106 |
+
graph = graph_data["graph"]
|
107 |
+
visited = set()
|
108 |
+
stack = [source_idx]
|
109 |
+
|
110 |
+
while stack:
|
111 |
+
current = stack.pop()
|
112 |
+
|
113 |
+
if current == target_idx:
|
114 |
+
return True
|
115 |
+
|
116 |
+
if current in visited:
|
117 |
+
continue
|
118 |
+
|
119 |
+
visited.add(current)
|
120 |
+
for neighbor in graph.neighbors(current):
|
121 |
+
stack.append(neighbor)
|
122 |
+
|
123 |
+
return False
|
124 |
+
|
125 |
+
def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
|
126 |
+
"""Add an edge between two nodes in a way that preserves a DAG structure."""
|
127 |
+
if self.current_graph_id is None:
|
128 |
+
raise Exception("Error: No current graph selected")
|
129 |
+
|
130 |
+
if node1 == node2:
|
131 |
+
print(f"Cannot add edge to the same node {node1}!")
|
132 |
+
return
|
133 |
+
|
134 |
+
graph_data = self._get_current_graph_data()
|
135 |
+
graph = graph_data["graph"]
|
136 |
+
node_map = graph_data["node_map"]
|
137 |
+
|
138 |
+
if node1 not in node_map or node2 not in node_map:
|
139 |
+
print(f"One or both nodes {node1}, {node2} do not exist in the current graph.")
|
140 |
+
return
|
141 |
+
|
142 |
+
idx1 = node_map[node1]
|
143 |
+
idx2 = node_map[node2]
|
144 |
+
|
145 |
+
# Check if adding this edge would create a cycle (i.e. if there is a path from node2 to node1)
|
146 |
+
if self._has_path(idx2, idx1):
|
147 |
+
print(f"An edge between {node1} -> {node2} already exists or would create a cycle!")
|
148 |
+
return
|
149 |
+
|
150 |
+
if relationship_type and weight:
|
151 |
+
edge_data = {"type": relationship_type, "weight": weight}
|
152 |
+
graph.add_edge(idx1, idx2, edge_data)
|
153 |
+
else:
|
154 |
+
raise ValueError("Error: Relationship type and weight must be provided")
|
155 |
+
print(f"Added edge between '{node1}' and '{node2}' in graph '{self.current_graph_id}' (type='{relationship_type}', weight={weight})")
|
156 |
+
|
157 |
+
def edge_exists(self, node1: str, node2: str) -> bool:
|
158 |
+
"""Check if an edge exists between two nodes."""
|
159 |
+
graph_data = self._get_current_graph_data()
|
160 |
+
graph = graph_data["graph"]
|
161 |
+
node_map = graph_data["node_map"]
|
162 |
+
|
163 |
+
if node1 not in node_map or node2 not in node_map:
|
164 |
+
return False
|
165 |
+
idx1 = node_map[node1]
|
166 |
+
idx2 = node_map[node2]
|
167 |
+
|
168 |
+
for edge in graph.out_edges(idx1):
|
169 |
+
if edge[1] == idx2:
|
170 |
+
return True
|
171 |
+
|
172 |
+
return False
|
173 |
+
|
174 |
+
def graph_exists(self) -> bool:
|
175 |
+
"""Check if a graph exists."""
|
176 |
+
return self.current_graph_id is not None and self.current_graph_id in self.graphs and len(self.graphs[self.current_graph_id]["node_map"]) > 0
|
177 |
+
|
178 |
+
def get_graphs(self) -> list:
|
179 |
+
"""Get detailed information about all existing graphs and their nodes."""
|
180 |
+
result = []
|
181 |
+
for graph_id, data in self.graphs.items():
|
182 |
+
metadata = data["metadata"]
|
183 |
+
node_map = data["node_map"]
|
184 |
+
graph = data["graph"]
|
185 |
+
nodes_info = []
|
186 |
+
|
187 |
+
for node_id, idx in node_map.items():
|
188 |
+
node_data = graph.get_node_data(idx)
|
189 |
+
nodes_info.append({
|
190 |
+
"id": node_data.get("id"),
|
191 |
+
"query": node_data.get("query"),
|
192 |
+
"data": node_data.get("data"),
|
193 |
+
"role": node_data.get("role"),
|
194 |
+
"pagerank": node_data.get("pagerank")
|
195 |
+
})
|
196 |
+
edge_count = len(graph.edge_list())
|
197 |
+
result.append({
|
198 |
+
"graph_info": {
|
199 |
+
"graph_id": graph_id,
|
200 |
+
"created": metadata.get("created"),
|
201 |
+
"updated": metadata.get("updated"),
|
202 |
+
"node_count": len(node_map),
|
203 |
+
"edge_count": edge_count,
|
204 |
+
"nodes": nodes_info
|
205 |
+
}
|
206 |
+
})
|
207 |
+
|
208 |
+
result.sort(key=lambda x: x["graph_info"]["created"], reverse=True)
|
209 |
+
return result
|
210 |
+
|
211 |
+
def select_graph(self, graph_id: str) -> bool:
|
212 |
+
"""Select a specific graph as the current working graph."""
|
213 |
+
if graph_id in self.graphs:
|
214 |
+
self.current_graph_id = graph_id
|
215 |
+
return True
|
216 |
+
return False
|
217 |
+
|
218 |
+
def create_new_graph(self) -> str:
|
219 |
+
"""Create a new graph instance and its ID."""
|
220 |
+
graph_id = str(uuid.uuid4())
|
221 |
+
graph = rx.PyDiGraph()
|
222 |
+
node_map = {}
|
223 |
+
metadata = {
|
224 |
+
"id": graph_id,
|
225 |
+
"created": time.time(),
|
226 |
+
"updated": time.time()
|
227 |
+
}
|
228 |
+
self.graphs[graph_id] = {"graph": graph, "node_map": node_map, "metadata": metadata}
|
229 |
+
self.current_graph_id = graph_id
|
230 |
+
|
231 |
+
return graph_id
|
232 |
+
|
233 |
+
def load_graph(self, node_id: str) -> bool:
|
234 |
+
"""Load an existing graph structure from memory based on a node ID."""
|
235 |
+
|
236 |
+
for gid, data in self.graphs.items():
|
237 |
+
if node_id in data["node_map"]:
|
238 |
+
self.current_graph_id = gid
|
239 |
+
|
240 |
+
for n_id in data["node_map"].keys():
|
241 |
+
if "SQ" in n_id:
|
242 |
+
num = int(''.join(filter(str.isdigit, n_id)) or 0)
|
243 |
+
self.node_counter = max(self.node_counter, num)
|
244 |
+
elif "SSQ" in n_id:
|
245 |
+
num = int(''.join(filter(str.isdigit, n_id)) or 0)
|
246 |
+
self.sub_node_counter = max(self.sub_node_counter, num)
|
247 |
+
|
248 |
+
self.node_counter += 1
|
249 |
+
self.sub_node_counter += 1
|
250 |
+
graph = data["graph"]
|
251 |
+
node_map = data["node_map"]
|
252 |
+
|
253 |
+
for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
|
254 |
+
if edge_data.get("type") == "logical":
|
255 |
+
source_id = graph.get_node_data(u).get("id")
|
256 |
+
target_id = graph.get_node_data(v).get("id")
|
257 |
+
connection = tuple(sorted([source_id, target_id]))
|
258 |
+
self.cross_connections.add(connection)
|
259 |
+
|
260 |
+
print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
|
261 |
+
return True
|
262 |
+
|
263 |
+
print(f"Graph with node_id {node_id} not found.")
|
264 |
+
|
265 |
+
return False
|
266 |
+
|
267 |
+
async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
|
268 |
+
"""Modify an existing graph structure by integrating a new query."""
|
269 |
+
graph_data = self._get_current_graph_data()
|
270 |
+
graph = graph_data["graph"]
|
271 |
+
node_map = graph_data["node_map"]
|
272 |
+
|
273 |
+
async def add_as_sibling(node_id: str, query: str):
|
274 |
+
if node_id not in node_map:
|
275 |
+
raise ValueError(f"Node {node_id} not found")
|
276 |
+
|
277 |
+
idx = node_map[node_id]
|
278 |
+
in_edges = graph.in_edges(idx)
|
279 |
+
|
280 |
+
if not in_edges:
|
281 |
+
raise ValueError(f"No parent found for node {node_id}")
|
282 |
+
|
283 |
+
parent_idx = in_edges[0][0]
|
284 |
+
parent_data = graph.get_node_data(parent_idx)
|
285 |
+
parent_id = parent_data.get("id")
|
286 |
+
|
287 |
+
if "SQ" in node_id:
|
288 |
+
self.node_counter += 1
|
289 |
+
new_node_id = f"SQ{self.node_counter}"
|
290 |
+
else:
|
291 |
+
self.sub_node_counter += 1
|
292 |
+
new_node_id = f"SSQ{self.sub_node_counter}"
|
293 |
+
|
294 |
+
self.add_node(new_node_id, query, role="independent")
|
295 |
+
self.add_edge(parent_id, new_node_id, relationship_type=in_edges[0][2].get("type"))
|
296 |
+
|
297 |
+
return new_node_id
|
298 |
+
|
299 |
+
async def add_as_child(node_id: str, query: str):
|
300 |
+
if "SQ" in node_id:
|
301 |
+
self.sub_node_counter += 1
|
302 |
+
new_node_id = f"SSQ{self.sub_node_counter}"
|
303 |
+
else:
|
304 |
+
self.node_counter += 1
|
305 |
+
new_node_id = f"SQ{self.node_counter}"
|
306 |
+
|
307 |
+
self.add_node(new_node_id, query, role="dependent")
|
308 |
+
self.add_edge(node_id, new_node_id, relationship_type="logical")
|
309 |
+
|
310 |
+
return new_node_id
|
311 |
+
|
312 |
+
def collect_graph_context() -> list:
|
313 |
+
"""Collect context from existing graph nodes."""
|
314 |
+
graph_data = self._get_current_graph_data()
|
315 |
+
graph = graph_data["graph"]
|
316 |
+
node_map = graph_data["node_map"]
|
317 |
+
nodes = []
|
318 |
+
|
319 |
+
for n_id, idx in node_map.items():
|
320 |
+
if n_id == self.root_node_id:
|
321 |
+
continue
|
322 |
+
node_data = graph.get_node_data(idx)
|
323 |
+
nodes.append({
|
324 |
+
"id": node_data.get("id"),
|
325 |
+
"query": node_data.get("query"),
|
326 |
+
"role": node_data.get("role")
|
327 |
+
})
|
328 |
+
|
329 |
+
nodes.sort(key=lambda x: (0 if x["id"].startswith("SQ") else (1 if x["id"].startswith("SSQ") else 2), x["id"]))
|
330 |
+
level_queries = {}
|
331 |
+
current_sq = None
|
332 |
+
|
333 |
+
for node in nodes:
|
334 |
+
node_id = node["id"]
|
335 |
+
if node_id.startswith("SQ"):
|
336 |
+
current_sq = node_id
|
337 |
+
|
338 |
+
if current_sq not in level_queries:
|
339 |
+
level_queries[current_sq] = {
|
340 |
+
"originalquery": node["query"],
|
341 |
+
"subqueries": []
|
342 |
+
}
|
343 |
+
level_queries[current_sq]["subqueries"].append({
|
344 |
+
"subquery": node["query"],
|
345 |
+
"role": node["role"],
|
346 |
+
"dependson": []
|
347 |
+
})
|
348 |
+
|
349 |
+
elif node_id.startswith("SSQ") and current_sq:
|
350 |
+
level_queries[current_sq]["subqueries"].append({
|
351 |
+
"subquery": node["query"],
|
352 |
+
"role": node["role"],
|
353 |
+
"dependson": []
|
354 |
+
})
|
355 |
+
|
356 |
+
return list(level_queries.values())
|
357 |
+
|
358 |
+
if similar_node_id not in node_map:
|
359 |
+
raise Exception(f"Node {similar_node_id} not found")
|
360 |
+
|
361 |
+
similar_node_data = graph.get_node_data(node_map[similar_node_id])
|
362 |
+
has_parent = len(graph.in_edges(node_map[similar_node_id])) > 0
|
363 |
+
|
364 |
+
context = collect_graph_context()
|
365 |
+
if similar_node_data.get("role") == "independent":
|
366 |
+
if has_parent:
|
367 |
+
new_node_id = await add_as_sibling(similar_node_id, new_query)
|
368 |
+
else:
|
369 |
+
new_node_id = await add_as_child(similar_node_id, new_query)
|
370 |
+
else:
|
371 |
+
new_node_id = await add_as_child(similar_node_id, new_query)
|
372 |
+
|
373 |
+
await self.build_graph(
|
374 |
+
query=new_query,
|
375 |
+
parent_node_id=new_node_id,
|
376 |
+
depth=1 if "SQ" in new_node_id else 2,
|
377 |
+
context=context,
|
378 |
+
session_id=session_id
|
379 |
+
)
|
380 |
+
|
381 |
+
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
|
382 |
+
depth: int = 0, threshold: float = 0.8, recurse: bool = True,
|
383 |
+
context: list = None, session_id: str = None, max_tokens_allowed: int = 128000):
|
384 |
+
"""Build a new graph structure in memory."""
|
385 |
+
async def process_node(node_id: str, sub_query: str, session_id: str, future: asyncio.Future, depth=depth, max_tokens_allowed=max_tokens_allowed):
|
386 |
+
try:
|
387 |
+
optimized_query = await self.search_engine.generate_optimized_query(sub_query)
|
388 |
+
results = await self.search_engine.search(
|
389 |
+
query=optimized_query,
|
390 |
+
num_results=10,
|
391 |
+
exclude_filetypes=["pdf"]
|
392 |
+
)
|
393 |
+
await self.emit_event("search_results_fetched", {
|
394 |
+
"node_id": node_id,
|
395 |
+
"sub_query": sub_query,
|
396 |
+
"optimized_query": optimized_query,
|
397 |
+
"search_results": results
|
398 |
+
})
|
399 |
+
filtered_urls = await self.search_engine.filter_urls(
|
400 |
+
sub_query,
|
401 |
+
"extensive research dynamic structure",
|
402 |
+
results
|
403 |
+
)
|
404 |
+
await self.emit_event("search_results_filtered", {
|
405 |
+
"node_id": node_id,
|
406 |
+
"sub_query": sub_query,
|
407 |
+
"filtered_urls": filtered_urls
|
408 |
+
})
|
409 |
+
urls = [result.get('link', 'No URL') for result in filtered_urls]
|
410 |
+
search_contents = await self.custom_crawler.fetch_page_contents(
|
411 |
+
urls,
|
412 |
+
sub_query,
|
413 |
+
session_id=session_id,
|
414 |
+
max_attempts=1,
|
415 |
+
timeout=30
|
416 |
+
)
|
417 |
+
await self.emit_event("search_contents_fetched", {
|
418 |
+
"node_id": node_id,
|
419 |
+
"sub_query": sub_query,
|
420 |
+
"contents": search_contents
|
421 |
+
})
|
422 |
+
|
423 |
+
contents = ""
|
424 |
+
for k, content in enumerate(search_contents, 1):
|
425 |
+
if isinstance(content, Exception):
|
426 |
+
print(f"Error fetching content: {content}")
|
427 |
+
elif content:
|
428 |
+
contents += f"Document {k}:\n{content}\n\n"
|
429 |
+
|
430 |
+
if contents.strip():
|
431 |
+
if depth == 0:
|
432 |
+
await self.emit_event("sub_query_processed", {
|
433 |
+
"node_id": node_id,
|
434 |
+
"sub_query": sub_query,
|
435 |
+
"contents": contents
|
436 |
+
})
|
437 |
+
|
438 |
+
token_count = self.llm.get_num_tokens(contents)
|
439 |
+
if token_count > max_tokens_allowed:
|
440 |
+
contents = await self.chunking.chunker(
|
441 |
+
text=contents,
|
442 |
+
query=sub_query,
|
443 |
+
max_tokens=max_tokens_allowed
|
444 |
+
)
|
445 |
+
print(f"Number of tokens in the answer: {token_count}")
|
446 |
+
print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
|
447 |
+
else:
|
448 |
+
if depth == 0:
|
449 |
+
await self.emit_event("sub_query_failed", {
|
450 |
+
"node_id": node_id,
|
451 |
+
"sub_query": sub_query,
|
452 |
+
"contents": contents
|
453 |
+
})
|
454 |
+
|
455 |
+
graph_data = self._get_current_graph_data()
|
456 |
+
graph = graph_data["graph"]
|
457 |
+
node_map = graph_data["node_map"]
|
458 |
+
|
459 |
+
if node_id in node_map:
|
460 |
+
idx = node_map[node_id]
|
461 |
+
node_data = graph.get_node_data(idx)
|
462 |
+
node_data["data"] = contents
|
463 |
+
future.set_result(contents)
|
464 |
+
except Exception as e:
|
465 |
+
print(f"Error processing node {node_id}: {str(e)}")
|
466 |
+
future.set_exception(e)
|
467 |
+
raise
|
468 |
+
|
469 |
+
async def process_dependent_node(node_id: str, sub_query: str, depth, dep_futures: list, future):
|
470 |
+
try:
|
471 |
+
dep_data = [await f for f in dep_futures]
|
472 |
+
modified_query = await self.query_processor.modify_query(
|
473 |
+
sub_query,
|
474 |
+
dep_data
|
475 |
+
)
|
476 |
+
loop = asyncio.get_running_loop()
|
477 |
+
embedding = await loop.run_in_executor(
|
478 |
+
self.executor,
|
479 |
+
self.model.encode,
|
480 |
+
modified_query
|
481 |
+
)
|
482 |
+
graph_data = self._get_current_graph_data()
|
483 |
+
graph = graph_data["graph"]
|
484 |
+
node_map = graph_data["node_map"]
|
485 |
+
|
486 |
+
if node_id in node_map:
|
487 |
+
idx = node_map[node_id]
|
488 |
+
node_data = graph.get_node_data(idx)
|
489 |
+
node_data["query"] = modified_query
|
490 |
+
node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
|
491 |
+
try:
|
492 |
+
if not future.done():
|
493 |
+
await process_node(node_id, modified_query, session_id, future, depth, max_tokens_allowed)
|
494 |
+
except Exception as e:
|
495 |
+
if not future.done():
|
496 |
+
future.set_exception(e)
|
497 |
+
raise
|
498 |
+
except Exception as e:
|
499 |
+
print(f"Error processing dependent node {node_id}: {str(e)}")
|
500 |
+
if not future.done():
|
501 |
+
future.set_exception(e)
|
502 |
+
raise
|
503 |
+
|
504 |
+
def create_cross_connections():
|
505 |
+
try:
|
506 |
+
relationships = self.get_node_relationships(relationship_type='logical')
|
507 |
+
|
508 |
+
for current_node_id, edges in relationships.items():
|
509 |
+
graph_data = self._get_current_graph_data()
|
510 |
+
graph = graph_data["graph"]
|
511 |
+
node_map = graph_data["node_map"]
|
512 |
+
|
513 |
+
if current_node_id not in node_map:
|
514 |
+
continue
|
515 |
+
|
516 |
+
idx = node_map[current_node_id]
|
517 |
+
node_data = graph.get_node_data(idx)
|
518 |
+
node_role = (node_data.get("role") or "").lower()
|
519 |
+
|
520 |
+
if node_role == 'dependent':
|
521 |
+
for source_id, target_id, edge_data in edges['in_edges']:
|
522 |
+
if not source_id or source_id == self.root_node_id:
|
523 |
+
continue
|
524 |
+
|
525 |
+
connection = tuple(sorted([current_node_id, source_id]))
|
526 |
+
if connection not in self.cross_connections:
|
527 |
+
if not self.edge_exists(source_id, current_node_id):
|
528 |
+
print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
|
529 |
+
self.add_edge(source_id, current_node_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
|
530 |
+
self.cross_connections.add(connection)
|
531 |
+
|
532 |
+
for source_id, target_id, edge_data in edges['out_edges']:
|
533 |
+
if not target_id or target_id == self.root_node_id:
|
534 |
+
continue
|
535 |
+
|
536 |
+
connection = tuple(sorted([current_node_id, target_id]))
|
537 |
+
if connection not in self.cross_connections:
|
538 |
+
if not self.edge_exists(current_node_id, target_id):
|
539 |
+
print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
|
540 |
+
self.add_edge(current_node_id, target_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
|
541 |
+
self.cross_connections.add(connection)
|
542 |
+
except Exception as e:
|
543 |
+
print(f"Error creating cross connections: {str(e)}")
|
544 |
+
raise
|
545 |
+
|
546 |
+
if depth > 1:
|
547 |
+
return
|
548 |
+
|
549 |
+
if context is None:
|
550 |
+
context = []
|
551 |
+
node_data_futures = {}
|
552 |
+
|
553 |
+
if parent_node_id is None:
|
554 |
+
self.add_node(self.root_node_id, query, data)
|
555 |
+
parent_node_id = self.root_node_id
|
556 |
+
|
557 |
+
intent = await self.query_processor.get_query_intent(query)
|
558 |
+
|
559 |
+
if depth == 0:
|
560 |
+
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent)
|
561 |
+
else:
|
562 |
+
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent, context)
|
563 |
+
|
564 |
+
if response_data:
|
565 |
+
context.append(response_data)
|
566 |
+
|
567 |
+
if len(sub_queries) > 1 and sub_queries[0] != query:
|
568 |
+
sub_query_ids = []
|
569 |
+
pre_req_nodes = {}
|
570 |
+
|
571 |
+
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
|
572 |
+
if depth == 0:
|
573 |
+
await self.emit_event("sub_query_created", {
|
574 |
+
"depth": depth,
|
575 |
+
"sub_query": sub_query,
|
576 |
+
"role": role,
|
577 |
+
"dependency": dependency,
|
578 |
+
"parent_node_id": parent_node_id,
|
579 |
+
})
|
580 |
+
|
581 |
+
if depth == 0:
|
582 |
+
self.node_counter += 1
|
583 |
+
sub_node_id = f"SQ{self.node_counter}"
|
584 |
+
else:
|
585 |
+
self.sub_node_counter += 1
|
586 |
+
sub_node_id = f"SSQ{self.sub_node_counter}"
|
587 |
+
|
588 |
+
sub_query_ids.append(sub_node_id)
|
589 |
+
self.add_node(sub_node_id, sub_query, role=role)
|
590 |
+
future = asyncio.Future()
|
591 |
+
node_data_futures[sub_node_id] = future
|
592 |
+
|
593 |
+
if role.lower() in ['pre-requisite', 'prerequisite']:
|
594 |
+
pre_req_nodes[idx] = sub_node_id
|
595 |
+
|
596 |
+
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
597 |
+
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
|
598 |
+
|
599 |
+
elif role.lower() == 'dependent':
|
600 |
+
if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
|
601 |
+
print(f"Dependency: {dependency}")
|
602 |
+
prev_deps, current_deps = dependency
|
603 |
+
|
604 |
+
if context and prev_deps not in [None, []]:
|
605 |
+
for dep_idx in prev_deps:
|
606 |
+
if dep_idx is not None:
|
607 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_idx)
|
608 |
+
|
609 |
+
if matching_nodes not in [None, []]:
|
610 |
+
dep_node_id = matching_nodes[0].get('node_id')
|
611 |
+
score = matching_nodes[0].get('score', 0)
|
612 |
+
|
613 |
+
if score >= 0.9:
|
614 |
+
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
|
615 |
+
|
616 |
+
if current_deps not in [None, []]:
|
617 |
+
for dep_idx in current_deps:
|
618 |
+
if dep_idx < len(sub_queries):
|
619 |
+
dep_node_id = sub_query_ids[dep_idx]
|
620 |
+
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
|
621 |
+
else:
|
622 |
+
raise ValueError(f"Invalid dependency index: {dep_idx}")
|
623 |
+
|
624 |
+
elif len(dependency) > 0:
|
625 |
+
for dep_idx in dependency:
|
626 |
+
if dep_idx < len(sub_queries):
|
627 |
+
dep_node_id = sub_query_ids[dep_idx]
|
628 |
+
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
|
629 |
+
else:
|
630 |
+
raise ValueError(f"Invalid dependency index: {dep_idx}")
|
631 |
+
else:
|
632 |
+
raise ValueError(f"Invalid dependency: {dependency}")
|
633 |
+
else:
|
634 |
+
raise ValueError(f"Unexpected role: {role}")
|
635 |
+
|
636 |
+
tasks = []
|
637 |
+
for idx in range(len(sub_queries)):
|
638 |
+
node_id = sub_query_ids[idx]
|
639 |
+
future = node_data_futures[node_id]
|
640 |
+
|
641 |
+
if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
642 |
+
tasks.append(process_node(node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed))
|
643 |
+
|
644 |
+
for idx in range(len(sub_queries)):
|
645 |
+
node_id = sub_query_ids[idx]
|
646 |
+
future = node_data_futures[node_id]
|
647 |
+
|
648 |
+
if roles[idx].lower() == 'dependent':
|
649 |
+
dep_futures = []
|
650 |
+
|
651 |
+
if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
|
652 |
+
prev_deps, current_deps = dependencies[idx]
|
653 |
+
|
654 |
+
if context and prev_deps not in [None, []]:
|
655 |
+
for context_idx, context_data in enumerate(context):
|
656 |
+
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
|
657 |
+
context_dep = prev_deps[context_idx]
|
658 |
+
|
659 |
+
if context_dep is not None and isinstance(context_data, dict) and 'subqueries' in context_data:
|
660 |
+
|
661 |
+
if context_dep < len(context_data['subqueries']):
|
662 |
+
dep_query = context_data['subqueries'][context_dep]['subquery']
|
663 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
664 |
+
|
665 |
+
if matching_nodes not in [None, []]:
|
666 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
667 |
+
score = float(matching_nodes[0].get('score', 0))
|
668 |
+
|
669 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
670 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
671 |
+
|
672 |
+
elif isinstance(prev_deps, int):
|
673 |
+
if prev_deps < len(context_data['subqueries']):
|
674 |
+
dep_query = context_data['subqueries'][prev_deps]['subquery']
|
675 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
676 |
+
|
677 |
+
if matching_nodes not in [None, []]:
|
678 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
679 |
+
score = matching_nodes[0].get('score', 0)
|
680 |
+
|
681 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
682 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
683 |
+
|
684 |
+
if current_deps not in [None, []]:
|
685 |
+
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
|
686 |
+
|
687 |
+
for dep_idx in current_deps_list:
|
688 |
+
if dep_idx < len(sub_queries):
|
689 |
+
dep_node_id = sub_query_ids[dep_idx]
|
690 |
+
if dep_node_id in node_data_futures:
|
691 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
692 |
+
|
693 |
+
tasks.append(process_dependent_node(node_id, sub_queries[idx], depth, dep_futures, future))
|
694 |
+
|
695 |
+
if depth == 0:
|
696 |
+
await self.emit_event("search_process_started", {
|
697 |
+
"depth": depth,
|
698 |
+
"sub_queries": sub_queries,
|
699 |
+
"roles": roles
|
700 |
+
})
|
701 |
+
|
702 |
+
await asyncio.gather(*tasks)
|
703 |
+
|
704 |
+
if recurse:
|
705 |
+
recursion_tasks = []
|
706 |
+
|
707 |
+
for idx, sub_query in enumerate(sub_queries):
|
708 |
+
try:
|
709 |
+
sub_node_id = sub_query_ids[idx]
|
710 |
+
recursion_tasks.append(
|
711 |
+
self.build_graph(
|
712 |
+
query=sub_query,
|
713 |
+
parent_node_id=sub_node_id,
|
714 |
+
depth=depth + 1,
|
715 |
+
threshold=threshold,
|
716 |
+
recurse=recurse,
|
717 |
+
context=context,
|
718 |
+
session_id=session_id
|
719 |
+
)
|
720 |
+
)
|
721 |
+
except Exception as e:
|
722 |
+
print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
|
723 |
+
continue
|
724 |
+
|
725 |
+
if recursion_tasks:
|
726 |
+
try:
|
727 |
+
await asyncio.gather(*recursion_tasks)
|
728 |
+
except Exception as e:
|
729 |
+
raise Exception(f"Error during recursive processing: {e}")
|
730 |
+
|
731 |
+
if depth == 0:
|
732 |
+
print("Graph building complete, processing final tasks...")
|
733 |
+
await self.emit_event("search_process_completed", {
|
734 |
+
"depth": depth,
|
735 |
+
"sub_queries": sub_queries,
|
736 |
+
"roles": roles
|
737 |
+
})
|
738 |
+
|
739 |
+
create_cross_connections()
|
740 |
+
print("All cross-connections have been created!")
|
741 |
+
print(f"Adding similarity edges with threshold {threshold}")
|
742 |
+
|
743 |
+
graph_data = self._get_current_graph_data()
|
744 |
+
node_map = graph_data["node_map"]
|
745 |
+
all_node_ids = list(node_map.keys())
|
746 |
+
|
747 |
+
for i, node1 in enumerate(all_node_ids):
|
748 |
+
for node2 in all_node_ids[i+1:]:
|
749 |
+
if not self.edge_exists(node1, node2):
|
750 |
+
self.add_edge_based_on_similarity_and_relevance(node1, node2, query, threshold)
|
751 |
+
|
752 |
+
print("All similarity edges have been added!")
|
753 |
+
|
754 |
+
async def process_graph(
|
755 |
+
self,
|
756 |
+
query: str,
|
757 |
+
data: str = None,
|
758 |
+
similarity_threshold: float = 0.8,
|
759 |
+
relevance_threshold: float = 0.7,
|
760 |
+
sub_sub_queries: bool = True,
|
761 |
+
session_id: str = None,
|
762 |
+
max_tokens_allowed: int = 128000
|
763 |
+
):
|
764 |
+
"""Process a query and manage graph creation/modification."""
|
765 |
+
def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
|
766 |
+
if self.current_graph_id is None:
|
767 |
+
raise Exception("Error: No current graph ID. Cannot check query similarity.")
|
768 |
+
|
769 |
+
graph_data = self._get_current_graph_data()
|
770 |
+
graph = graph_data["graph"]
|
771 |
+
node_map = graph_data["node_map"]
|
772 |
+
similarities = []
|
773 |
+
|
774 |
+
if not node_map:
|
775 |
+
return {"should_create_new": True}
|
776 |
+
|
777 |
+
for node_id, idx in node_map.items():
|
778 |
+
node_data = graph.get_node_data(idx)
|
779 |
+
|
780 |
+
if not node_data.get("query"):
|
781 |
+
continue
|
782 |
+
|
783 |
+
similarity = self.calculate_query_similarity(new_query, node_data.get("query"))
|
784 |
+
if similarity >= similarity_threshold:
|
785 |
+
similarities.append({
|
786 |
+
"node_id": node_id,
|
787 |
+
"query": node_data.get("query"),
|
788 |
+
"score": similarity,
|
789 |
+
"role": node_data.get("role")
|
790 |
+
})
|
791 |
+
|
792 |
+
if not similarities:
|
793 |
+
print(f"No similar queries found above threshold {similarity_threshold}")
|
794 |
+
return {"should_create_new": True}
|
795 |
+
|
796 |
+
best_match = max(similarities, key=lambda x: x["score"])
|
797 |
+
|
798 |
+
rel_type = "root"
|
799 |
+
if "SSQ" in best_match["node_id"]:
|
800 |
+
rel_type = "sub-sub"
|
801 |
+
|
802 |
+
elif "SQ" in best_match["node_id"]:
|
803 |
+
rel_type = "sub"
|
804 |
+
|
805 |
+
return {
|
806 |
+
"most_similar_query": best_match["query"],
|
807 |
+
"similarity_score": best_match["score"],
|
808 |
+
"relationship_type": rel_type,
|
809 |
+
"node_id": best_match["node_id"],
|
810 |
+
"should_create_new": best_match["score"] < similarity_threshold
|
811 |
+
}
|
812 |
+
try:
|
813 |
+
graphs = self.get_graphs()
|
814 |
+
|
815 |
+
if not graphs:
|
816 |
+
print("No existing graphs found. Creating new graph.")
|
817 |
+
self.create_new_graph()
|
818 |
+
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
|
819 |
+
await self.build_graph(
|
820 |
+
query=query,
|
821 |
+
data=data,
|
822 |
+
threshold=relevance_threshold,
|
823 |
+
recurse=sub_sub_queries,
|
824 |
+
session_id=session_id,
|
825 |
+
max_tokens_allowed=max_tokens_allowed
|
826 |
+
)
|
827 |
+
gc.collect()
|
828 |
+
self.prune_edges()
|
829 |
+
self.update_pagerank()
|
830 |
+
self.verify_graph_integrity()
|
831 |
+
self.verify_graph_consistency()
|
832 |
+
return
|
833 |
+
|
834 |
+
max_similarity = 0
|
835 |
+
most_similar_graph = None
|
836 |
+
consolidated_graphs = {}
|
837 |
+
|
838 |
+
for graph_obj in graphs:
|
839 |
+
graph_info = graph_obj.get("graph_info")
|
840 |
+
if not graph_info:
|
841 |
+
continue
|
842 |
+
|
843 |
+
graph_id = graph_info.get("graph_id")
|
844 |
+
|
845 |
+
if not graph_id:
|
846 |
+
continue
|
847 |
+
|
848 |
+
if graph_id not in consolidated_graphs:
|
849 |
+
consolidated_graphs[graph_id] = {
|
850 |
+
"graph_id": graph_id,
|
851 |
+
"nodes": []
|
852 |
+
}
|
853 |
+
|
854 |
+
if graph_info.get("nodes"):
|
855 |
+
consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])
|
856 |
+
|
857 |
+
for graph_id, graph_data in consolidated_graphs.items():
|
858 |
+
nodes = graph_data["nodes"]
|
859 |
+
|
860 |
+
for node in nodes:
|
861 |
+
if node.get("query"):
|
862 |
+
similarity = self.calculate_query_similarity(query, node["query"])
|
863 |
+
|
864 |
+
if node.get("id", "").startswith("SQ"):
|
865 |
+
asyncio.create_task(self.emit_event("retrieved_sub_query", {
|
866 |
+
"sub_query": node["query"]
|
867 |
+
}))
|
868 |
+
|
869 |
+
if similarity > max_similarity:
|
870 |
+
max_similarity = similarity
|
871 |
+
most_similar_graph = graph_id
|
872 |
+
|
873 |
+
if max_similarity >= similarity_threshold:
|
874 |
+
print(f"Found similar query with score {round(max_similarity, 2)}")
|
875 |
+
self.current_graph_id = most_similar_graph
|
876 |
+
|
877 |
+
if round(max_similarity, 2) == 1.0:
|
878 |
+
print("Loading and using existing graph")
|
879 |
+
await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
|
880 |
+
success = self.load_graph(self.root_node_id)
|
881 |
+
|
882 |
+
if not success:
|
883 |
+
raise Exception("Failed to load existing graph")
|
884 |
+
|
885 |
+
else:
|
886 |
+
print("Checking for node-level similarity...")
|
887 |
+
similarity_info = check_query_similarity(
|
888 |
+
query,
|
889 |
+
similarity_threshold
|
890 |
+
)
|
891 |
+
|
892 |
+
if similarity_info["relationship_type"] in ["sub", "sub-sub"]:
|
893 |
+
print(f"Most Similar Query: {similarity_info['most_similar_query']}")
|
894 |
+
print("Modifying existing graph structure")
|
895 |
+
await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
|
896 |
+
await self.modify_graph(
|
897 |
+
query,
|
898 |
+
similarity_info["node_id"],
|
899 |
+
session_id=session_id
|
900 |
+
)
|
901 |
+
gc.collect()
|
902 |
+
self.prune_edges()
|
903 |
+
self.update_pagerank()
|
904 |
+
self.verify_graph_integrity()
|
905 |
+
self.verify_graph_consistency()
|
906 |
+
|
907 |
+
else:
|
908 |
+
print(f"Creating new graph for query: {query}")
|
909 |
+
self.create_new_graph()
|
910 |
+
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
|
911 |
+
await self.build_graph(
|
912 |
+
query=query,
|
913 |
+
data=data,
|
914 |
+
threshold=relevance_threshold,
|
915 |
+
recurse=sub_sub_queries,
|
916 |
+
session_id=session_id,
|
917 |
+
max_tokens_allowed=max_tokens_allowed
|
918 |
+
)
|
919 |
+
gc.collect()
|
920 |
+
self.prune_edges()
|
921 |
+
self.update_pagerank()
|
922 |
+
self.verify_graph_integrity()
|
923 |
+
self.verify_graph_consistency()
|
924 |
+
except Exception as e:
|
925 |
+
print(f"Error in process_graph: {str(e)}")
|
926 |
+
raise
|
927 |
+
|
928 |
+
def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
|
929 |
+
"""Add edges based on node similarity and relevance."""
|
930 |
+
graph_data = self._get_current_graph_data()
|
931 |
+
graph = graph_data["graph"]
|
932 |
+
node_map = graph_data["node_map"]
|
933 |
+
|
934 |
+
if node1_id not in node_map or node2_id not in node_map:
|
935 |
+
return
|
936 |
+
|
937 |
+
idx1 = node_map[node1_id]
|
938 |
+
idx2 = node_map[node2_id]
|
939 |
+
node1_data = graph.get_node_data(idx1)
|
940 |
+
node2_data = graph.get_node_data(idx2)
|
941 |
+
|
942 |
+
if not all([node1_data.get("embedding"), node2_data.get("embedding"), node1_data.get("data"), node2_data.get("data")]):
|
943 |
+
return
|
944 |
+
|
945 |
+
similarity = self.cosine_similarity(node1_data["embedding"], node2_data["embedding"])
|
946 |
+
query_relevance1 = self.calculate_relevance(query, node1_data["data"])
|
947 |
+
query_relevance2 = self.calculate_relevance(query, node2_data["data"])
|
948 |
+
node_relevance = self.calculate_relevance(node1_data["data"], node2_data["data"])
|
949 |
+
weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4
|
950 |
+
|
951 |
+
if weight >= threshold:
|
952 |
+
self.add_edge(node1_id, node2_id, weight=weight, relationship_type='similarity_and_relevance')
|
953 |
+
print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")
|
954 |
+
|
955 |
+
def calculate_relevance(self, data1: str, data2: str) -> float:
|
956 |
+
"""Calculate relevance between two data strings."""
|
957 |
+
try:
|
958 |
+
if not data1 or not data2:
|
959 |
+
return 0.0
|
960 |
+
|
961 |
+
P, R, F1 = self.scorer.score([data1], [data2])
|
962 |
+
return F1.mean().item()
|
963 |
+
except Exception as e:
|
964 |
+
print(f"Error calculating relevance: {str(e)}")
|
965 |
+
return 0.0
|
966 |
+
|
967 |
+
def calculate_query_similarity(self, query1: str, query2: str) -> float:
|
968 |
+
"""Calculate similarity between two queries."""
|
969 |
+
try:
|
970 |
+
embedding1 = self.model.encode(query1).tolist()
|
971 |
+
embedding2 = self.model.encode(query2).tolist()
|
972 |
+
return self.cosine_similarity(embedding1, embedding2)
|
973 |
+
except Exception as e:
|
974 |
+
print(f"Error calculating query similarity: {str(e)}")
|
975 |
+
return 0.0
|
976 |
+
|
977 |
+
def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
|
978 |
+
"""Get similarities and relevance between nodes."""
|
979 |
+
try:
|
980 |
+
graph_data = self._get_current_graph_data()
|
981 |
+
graph = graph_data["graph"]
|
982 |
+
node_map = graph_data["node_map"]
|
983 |
+
nodes = []
|
984 |
+
|
985 |
+
for node_id, idx in node_map.items():
|
986 |
+
node_data = graph.get_node_data(idx)
|
987 |
+
nodes.append({
|
988 |
+
"id": node_data.get("id"),
|
989 |
+
"embedding": node_data.get("embedding"),
|
990 |
+
"data": node_data.get("data")
|
991 |
+
})
|
992 |
+
|
993 |
+
similarities = []
|
994 |
+
for i, node1 in enumerate(nodes):
|
995 |
+
for node2 in nodes[i + 1:]:
|
996 |
+
similarity = self.cosine_similarity(node1["embedding"], node2["embedding"])
|
997 |
+
relevance = self.calculate_relevance(node1["data"], node2["data"])
|
998 |
+
weight = (similarity + relevance) / 2
|
999 |
+
|
1000 |
+
if weight >= threshold:
|
1001 |
+
similarities.append({
|
1002 |
+
'node1': node1["id"],
|
1003 |
+
'node2': node2["id"],
|
1004 |
+
'similarity': similarity,
|
1005 |
+
'relevance': relevance,
|
1006 |
+
'weight': weight
|
1007 |
+
})
|
1008 |
+
|
1009 |
+
return similarities
|
1010 |
+
except Exception as e:
|
1011 |
+
print(f"Error getting similarities and relevance: {str(e)}")
|
1012 |
+
return []
|
1013 |
+
|
1014 |
+
+    def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
+        """Get relationships between nodes with filtering options."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+        relationships = {}
+
+        for n_id, idx in node_map.items():
+            if n_id == self.root_node_id:
+                continue
+
+            node_data = graph.get_node_data(idx)
+
+            if node_id and n_id != node_id:
+                continue
+
+            if role and node_data.get("role") != role:
+                continue
+
+            in_edges = []
+            for u, v, edge_data in graph.in_edges(idx):
+                source_id = graph.get_node_data(u).get("id")
+                in_edges.append((source_id, n_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
+
+            out_edges = []
+            for u, v, edge_data in graph.out_edges(idx):
+                target_id = graph.get_node_data(v).get("id")
+                out_edges.append((n_id, target_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
+
+            relationships[n_id] = {"in_edges": in_edges, "out_edges": out_edges}
+
+        return relationships
+
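A quick sketch of the rustworkx accessors this method relies on: in_edges() and out_edges() both yield (source_index, target_index, payload) triples, which is why the loops unpack three values. A toy check, assuming only rustworkx:

import rustworkx as rx

g = rx.PyDiGraph()
a = g.add_node({"id": "A"})
b = g.add_node({"id": "B"})
g.add_edge(a, b, {"weight": 0.9, "type": "similarity_and_relevance"})

for u, v, payload in g.in_edges(b):
    print(u, v, payload)   # 0 1 {'weight': 0.9, 'type': 'similarity_and_relevance'}
for u, v, payload in g.out_edges(a):
    print(u, v, payload)   # the same triple, seen from the source side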
+    def find_nodes_by_properties(self, query: str = None, embedding: list = None,
+                                 node_data: dict = None, similarity_threshold: float = 0.8) -> list:
+        """Find nodes based on properties."""
+        try:
+            graph_data = self._get_current_graph_data()
+            graph = graph_data["graph"]
+            node_map = graph_data["node_map"]
+            matching_nodes = []
+
+            for n_id, idx in node_map.items():
+                data = graph.get_node_data(idx)
+                match_score = 0
+                matches = 0
+
+                if query and query.lower() in data.get("query", "").lower():
+                    match_score += 1
+                    matches += 1
+
+                if embedding and "embedding" in data:
+                    sim = self.cosine_similarity(embedding, data["embedding"])
+
+                    if sim >= similarity_threshold:
+                        match_score += sim
+                        matches += 1
+
+                if node_data:
+                    data_matches = sum(1 for k, v in node_data.items() if k in data and data[k] == v)
+
+                    if data_matches > 0:
+                        match_score += data_matches / len(node_data)
+                        matches += 1
+
+                if matches > 0:
+                    matching_nodes.append({
+                        "node_id": n_id,
+                        "score": match_score / matches,
+                        "data": data
+                    })
+
+            matching_nodes.sort(key=lambda x: x["score"], reverse=True)
+
+            return matching_nodes
+        except Exception as e:
+            print(f"Error finding nodes by properties: {str(e)}")
+            raise
+
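Hypothetical usage of the property search above; the query string and threshold are made-up values, and graph_search is the GraphRAG instance constructed in the __main__ block at the bottom of this file:

matches = graph_search.find_nodes_by_properties(query="renewable energy", similarity_threshold=0.85)
for m in matches[:3]:
    print(m["node_id"], round(m["score"], 2))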
+    def query_graph(self, query: str) -> str:
+        """Query the graph for a specific query, collecting data from the entire relevant subgraph."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+        target_node_id = None
+
+        for n_id, idx in node_map.items():
+            if graph.get_node_data(idx).get("query") == query:
+                target_node_id = n_id
+                break
+
+        if not target_node_id:
+            raise ValueError(f"Query node not found for: {query}")
+
+        datas = []
+        start_idx = node_map[target_node_id]
+        visited = set()
+        stack = [start_idx]
+
+        while stack:
+            current = stack.pop()
+
+            if current in visited:
+                continue
+            visited.add(current)
+            current_data = graph.get_node_data(current)
+
+            if current_data.get("data") and current_data.get("data").strip():
+                datas.append(current_data.get("data").strip())
+
+            for neighbor in graph.neighbors(current):
+                if neighbor not in visited:
+                    stack.append(neighbor)
+
+        if not datas:
+            print(f"No data found for: {query}")
+            return ""
+
+        return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)])
+
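The stack-based walk above is a plain depth-first traversal over out-neighbors. rustworkx can express the same reachability set directly with rx.descendants(); a small sketch under that assumption, with node payloads mirroring the shape used in this class (iteration order over the set is not guaranteed):

import rustworkx as rx

g = rx.PyDiGraph()
root = g.add_node({"query": "root query", "data": "root data"})
child = g.add_node({"query": "sub query", "data": "child data"})
leaf = g.add_node({"query": "leaf query", "data": ""})
g.add_edge(root, child, {"weight": 1.0})
g.add_edge(child, leaf, {"weight": 1.0})

# descendants() returns every index reachable from root (excluding root itself).
reachable = {root} | rx.descendants(g, root)
datas = [g[i]["data"].strip() for i in reachable if g[i]["data"].strip()]
print("\n\n".join(f"Data {n+1}:\n{d}" for n, d in enumerate(datas)))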
+    def prune_edges(self, max_edges: int = 1000):
+        """Prune excess edges while preserving node data."""
+        print(f"Pruning edges to maximum {max_edges} edges...")
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        # weighted_edge_list() yields (source, target, payload) triples; the
+        # payload is needed below to sort edges by their stored weight.
+        all_edges = list(graph.weighted_edge_list())
+        current_edges = len(all_edges)
+
+        if current_edges > max_edges:
+            sorted_edges = sorted(all_edges, key=lambda x: x[2].get("weight", 1.0), reverse=True)
+            edges_to_keep = set()
+
+            for edge in sorted_edges[:max_edges]:
+                edges_to_keep.add((edge[0], edge[1]))
+
+            edges_to_remove = []
+            for edge in all_edges:
+                if (edge[0], edge[1]) not in edges_to_keep:
+                    edges_to_remove.append((edge[0], edge[1]))
+
+            for u, v in edges_to_remove:
+                try:
+                    graph.remove_edge(u, v)
+                except Exception as e:
+                    print(f"Error removing edge from {u} to {v}: {e}")
+
+            print(f"Pruned edges. Kept top {max_edges} edges by weight.")
+        else:
+            print("No pruning required. Current edge count is within limits.")
+
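A toy run of the pruning rule, assuming only rustworkx: rank edges by their stored weight via weighted_edge_list(), keep the top k, and drop the rest.

import rustworkx as rx

g = rx.PyDiGraph()
n = [g.add_node(i) for i in range(3)]
g.add_edge(n[0], n[1], {"weight": 0.9})
g.add_edge(n[1], n[2], {"weight": 0.2})
g.add_edge(n[0], n[2], {"weight": 0.5})

# Sort (source, target, payload) triples by weight, descending.
ranked = sorted(g.weighted_edge_list(), key=lambda e: e[2]["weight"], reverse=True)
for u, v, _payload in ranked[2:]:  # keep only the top 2 edges
    g.remove_edge(u, v)

print(list(g.edge_list()))  # [(0, 1), (0, 2)] -- the 0.2 edge is gone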
+    def update_pagerank(self):
+        """Update PageRank values using Rustworkx's pagerank algorithm."""
+        if not self.current_graph_id:
+            print("No current graph selected. Cannot compute PageRank.")
+            return
+
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+
+        try:
+            pr = rx.pagerank(graph, weight_fn=lambda e: e.get("weight", 1.0))
+            node_map = graph_data["node_map"]
+
+            for n_id, idx in node_map.items():
+                node_data = graph.get_node_data(idx)
+                node_data["pagerank"] = pr[idx]
+
+            print("PageRank updated successfully")
+        except Exception as e:
+            print(f"Error updating PageRank: {str(e)}")
+            raise
+
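For reference, the rustworkx PageRank call in isolation: weight_fn maps each edge payload to a float, defaulting to 1.0 when no weight was stored, and the result maps node indices to scores. A minimal sketch:

import rustworkx as rx

g = rx.PyDiGraph()
a, b, c = (g.add_node({"id": x}) for x in "abc")
g.add_edge(a, b, {"weight": 2.0})
g.add_edge(b, c, {"weight": 1.0})
g.add_edge(c, a, {"weight": 1.0})

pr = rx.pagerank(g, weight_fn=lambda e: e.get("weight", 1.0))
for idx in g.node_indices():
    print(g[idx]["id"], round(pr[idx], 3))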
+    def display_graph(self):
+        """Render the current graph with PyVis and return the HTML as a string."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+        net = Network(height="600px", width="100%", directed=True, bgcolor="#222222", font_color="white")
+        net.options = {"physics": {"enabled": False}}
+        all_nodes = set()
+        all_edges = []
+
+        for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
+            source_data = graph.get_node_data(u)
+            target_data = graph.get_node_data(v)
+            source_id = source_data.get("id")
+            target_id = target_data.get("id")
+            source_tooltip = f"Query: {source_data.get('query', 'N/A')}"
+            target_tooltip = f"Query: {target_data.get('query', 'N/A')}"
+
+            if source_id not in all_nodes:
+                net.add_node(source_id, label=source_id, title=source_tooltip, size=20, color="#00cc66")
+                all_nodes.add(source_id)
+
+            if target_id not in all_nodes:
+                net.add_node(target_id, label=target_id, title=target_tooltip, size=20, color="#00cc66")
+                all_nodes.add(target_id)
+
+            edge_type = edge_data.get("type", "N/A")
+            edge_weight = edge_data.get("weight", "N/A")
+            edge_tooltip = f"Weight: {edge_weight}"
+            all_edges.append({
+                "from": source_id,
+                "to": target_id,
+                "label": edge_type,
+                "title": edge_tooltip
+            })
+
+        for edge in all_edges:
+            net.add_edge(edge["from"], edge["to"], title=edge["title"], color="#cccccc")
+
+        net.options["layout"] = {"improvedLayout": True}
+        net.options["interaction"] = {"dragNodes": True}
+
+        net.save_graph("temp_graph.html")
+
+        with open("temp_graph.html", "r", encoding="utf-8") as f:
+            html_str = f.read()
+        os.remove("temp_graph.html")
+        return html_str
+
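One detail worth knowing about the zip above: graph.edge_list() and graph.edges() are expected to enumerate edges in the same edge-index order, which is what keeps each (source, target) pair aligned with its payload; this holds as long as the graph is not mutated between the two calls.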
+    def verify_graph_integrity(self):
+        """Verify and fix graph integrity issues."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+        orphaned = []
+
+        for n_id, idx in node_map.items():
+            if not graph.in_edges(idx) and not graph.out_edges(idx):
+                orphaned.append(n_id)
+
+        if orphaned:
+            print(f"Found orphaned nodes: {orphaned}")
+
+        invalid_edges = []
+        for u, v in graph.edge_list():
+            target_data = graph.get_node_data(v)
+
+            if target_data.get("graph_id") != self.current_graph_id:
+                invalid_edges.append((graph.get_node_data(u).get("id"), target_data.get("id")))
+
+        if invalid_edges:
+            print(f"Found invalid edges: {invalid_edges}")
+            edges_to_remove = []
+
+            for u, v in graph.edge_list():
+                if graph.get_node_data(v).get("graph_id") != self.current_graph_id:
+                    edges_to_remove.append((u, v))
+
+            for u, v in edges_to_remove:
+                try:
+                    graph.remove_edge(u, v)
+                except Exception as e:
+                    print(f"Error removing invalid edge from {u} to {v}: {e}")
+
+        print("Graph integrity verified successfully")
+
+        return True
+
+    def verify_graph_consistency(self):
+        """Verify consistency of the in-memory graph."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+
+        for n_id, idx in node_map.items():
+            node_data = graph.get_node_data(idx)
+
+            if node_data.get("id") is None or node_data.get("query") is None:
+                raise ValueError("Found nodes with missing required properties")
+
+        for edge_data in graph.edges():
+            if edge_data.get("type") is None or edge_data.get("weight") is None:
+                raise ValueError("Found relationships with missing required properties")
+
+        print("Graph consistency verified successfully")
+
+        return True
+
1294 |
+
async def close(self):
|
1295 |
+
"""Properly cleanup all resources."""
|
1296 |
+
try:
|
1297 |
+
if hasattr(self, 'executor'):
|
1298 |
+
self.executor.shutdown(wait=True)
|
1299 |
+
|
1300 |
+
if hasattr(self, 'crawler'):
|
1301 |
+
await asyncio.shield(self.crawler.cleanup_expired_sessions())
|
1302 |
+
await asyncio.shield(self.crawler.cleanup_browser_context(getattr(self, "session_id", None)))
|
1303 |
+
except Exception as e:
|
1304 |
+
print(f"Error during cleanup: {e}")
|
1305 |
+
|
1306 |
+
+    @staticmethod
+    def cosine_similarity(v1: List[float], v2: List[float]) -> float:
+        """Calculate cosine similarity between two vectors."""
+        try:
+            v1_array = np.array(v1)
+            v2_array = np.array(v2)
+            norm_product = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
+
+            # Guard against zero-magnitude vectors to avoid a division by zero.
+            if norm_product == 0:
+                return 0.0
+            return float(np.dot(v1_array, v2_array) / norm_product)
+        except Exception as e:
+            print(f"Error calculating cosine similarity: {str(e)}")
+            return 0.0
+
+if __name__ == "__main__":
+    import os
+    from dotenv import load_dotenv
+    from src.reasoning.reasoner import Reasoner
+    from src.evaluation.evaluator import Evaluator
+
+    load_dotenv()
+
+    graph_search = GraphRAG(num_workers=24)
+    evaluator = Evaluator()
+    reasoner = Reasoner()
+
+    async def test_graph_search():
+        # Sample data for testing
+        queries = [
+            """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
+            Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
+            1. Carbon pricing mechanisms
+            2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
+            3. Cross-border climate finance initiatives
+
+            Highlight the unique challenges faced by both advanced and emerging economies in addressing:
+            1. Energy poverty
+            2. Supply chain disruptions
+            3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
+
+            Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
+            Note any significant policy gaps, regional collaborations, or innovative best practices.
+            Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
+            1. Technological breakthroughs
+            2. Global market trends
+            3. Potential climate-related disasters
+
+            Present your analysis as a detailed, well-formatted report.""",
+            """Analyse the impact of 'hot-money' on the value of the Indian Rupee and answer the following questions:-
+            1. How does it affect the exchange rate?
+            2. How can it be mitigated/eliminated?
+            3. Why is it a problem?
+            4. What are the consequences?
+            5. What are the alternatives?
+                - Evaluate the alternatives for pros and cons.
+                - Evaluate the impact of alternatives on the exchange rate.
+                - How can they be implemented?
+                - What are the consequences of each alternative?
+                - Evaluate the feasibility of the alternatives.
+                - Pick top 5 alternatives and justify your choices in detail.
+            6. What are the implications for the Indian economy? Furthermore:-
+                - Evaluate the impact of the chosen alternatives on the Indian economy.""",
+            """Inflation has been an intrinsic part of human civilization since the very beginning. Answer the following questions:-
+            1. How true is the above statement?
+            2. What are the causes of inflation?
+            3. What are the consequences of inflation?
+            4. Can we completely eliminate inflation?""",
+            """Perform a detailed comparison between the ancient Greek and Roman civilizations.
+            1. What were the key differences between the two civilizations?
+                - Evaluate the differences in governance, society, and culture
+                - Evaluate the differences in economy, trade, and military
+                - Evaluate the differences in technology and infrastructure
+            2. What were the similarities between the two civilizations?
+                - Evaluate the similarities in governance, society, and culture
+                - Evaluate the similarities in economy, trade, and military
+                - Evaluate the similarities in technology and infrastructure
+            3. How did these two civilizations influence each other?
+                - Evaluate the influence of one civilization on the other
+            4. How did these two civilizations influence the modern world?
+            5. Was there another civilization that influenced these two? If yes, how?""",
+            """Evaluate the long-term effects of colonialism on economic development in Asia:-
+            1. Include case studies of at least five different countries
+            2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
+                - Evaluate the impact of colonialism on the economy of the country
+                - Evaluate the impact of colonialism on the economy of the region
+                - Evaluate the impact of colonialism on the economy of the world
+            3. How do these effects compare to Africa?"""
+        ]
+        follow_on_queries = [
+            "How is 'hot-money' related to the current economic situation in India?",
+            "What is inflation?",
+            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
+            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
+        ]
+
+        while True:
+            print("\n\nEnter query (finish input with an empty line):")
+            query_lines = []
+
+            while True:
+                line = input()
+
+                if line.strip() == "":
+                    break
+                query_lines.append(line)
+
+            query = "\n".join(query_lines).strip()
+
+            if query.strip().lower() == "exit":
+                break
+            print("\n\n" + "="*15 + " Processing Query " + "="*15 + "\n\n")
+
+            await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)
+
+            answer = graph_search.query_graph(query)
+
+            response = ""
+            async for chunk in reasoner.reason(query, answer):
+                response += chunk
+                print(chunk, end="", flush=True)
+
+            graph_search.display_graph()
+
+            evaluation = await evaluator.evaluate_response(query, response, [answer])
+            print(f"Faithfulness: {evaluation['faithfulness']}")
+            print(f"Answer Relevancy: {evaluation['answer relevancy']}")
+            print(f"Contextual Recall: {evaluation['contextual recall']}")
+
+        await graph_search.close()
+
+    asyncio.run(test_graph_search())