Hemang Thakur committed on
Commit 25f9610 · 1 Parent(s): d8479e9

removed neo4j and replaced with rustworkx
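The heart of the change: graph storage moves from a Neo4j server to an in-memory rustworkx.PyDiGraph inside the new GraphRAG class. A minimal sketch of the pattern the new module relies on (string ids mapped to integer node indices, dict payloads on nodes and edges; the "QR"/"SQ1" ids mirror the naming used in graph_rag.py below):

import rustworkx as rx

# Build a small directed graph with dict payloads, as GraphRAG does.
graph = rx.PyDiGraph()
node_map = {}  # string node id -> rustworkx integer index

for node_id, query in [("QR", "root query"), ("SQ1", "first sub-query")]:
    node_map[node_id] = graph.add_node({"id": node_id, "query": query})

# Edges also carry dict payloads with a type and a weight.
graph.add_edge(node_map["QR"], node_map["SQ1"], {"type": "hierarchical", "weight": 1.0})

# PageRank now runs locally instead of through a database query.
pr = rx.pagerank(graph, weight_fn=lambda e: e.get("weight", 1.0))
for node_id, idx in node_map.items():
    print(node_id, round(pr[idx], 3))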
frontend/src/Components/AiComponents/Graph.js CHANGED
@@ -2,19 +2,19 @@ import React, { useState, useEffect } from 'react';
 import { FaTimes } from 'react-icons/fa';
 import './Graph.css';
 
-export default function Graph({ open, onClose, payload }) {
+export default function Graph({ open, onClose }) {
   const [graphHtml, setGraphHtml] = useState("");
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState("");
 
   useEffect(() => {
-    if (open && payload) {
+    if (open) {
       setLoading(true);
       setError("");
       fetch("/action/graph", {
         method: "POST",
         headers: { "Content-Type": "application/json" },
-        body: JSON.stringify(payload)
+        body: JSON.stringify({})
       })
       .then(res => res.json())
       .then(data => {
frontend/src/Components/IntialSetting.js CHANGED
@@ -90,9 +90,9 @@ function IntialSetting(props) {
   const modelAPIKeys = form.elements["model-api"].value;
   const braveAPIKey = form.elements["brave-api"].value;
   const proxyList = form.elements["proxy-list"].value;
-  const neo4jURL = form.elements["neo4j-url"].value;
-  const neo4jUsername = form.elements["neo4j-username"].value;
-  const neo4jPassword = form.elements["neo4j-password"].value;
+  // const neo4jURL = form.elements["neo4j-url"].value;
+  // const neo4jUsername = form.elements["neo4j-username"].value;
+  // const neo4jPassword = form.elements["neo4j-password"].value;
 
   // Validate required fields and collect missing field names
   const missingFields = [];
@@ -100,9 +100,9 @@ function IntialSetting(props) {
   if (!modelName || modelName.trim() === "") missingFields.push("Model Name");
   if (!modelAPIKeys || modelAPIKeys.trim() === "") missingFields.push("Model API Key");
   if (!braveAPIKey || braveAPIKey.trim() === "") missingFields.push("Brave Search API Key");
-  if (!neo4jURL || neo4jURL.trim() === "") missingFields.push("Neo4j URL");
-  if (!neo4jUsername || neo4jUsername.trim() === "") missingFields.push("Neo4j Username");
-  if (!neo4jPassword || neo4jPassword.trim() === "") missingFields.push("Neo4j Password");
+  // if (!neo4jURL || neo4jURL.trim() === "") missingFields.push("Neo4j URL");
+  // if (!neo4jUsername || neo4jUsername.trim() === "") missingFields.push("Neo4j Username");
+  // if (!neo4jPassword || neo4jPassword.trim() === "") missingFields.push("Neo4j Password");
 
   // If any required fields are missing, show an error notification
   if (missingFields.length > 0) {
@@ -120,9 +120,9 @@ function IntialSetting(props) {
     "Model_Name": modelName,
     "Model_API_Keys": modelAPIKeys,
     "Brave_Search_API_Key": braveAPIKey,
-    "Neo4j_URL": neo4jURL,
-    "Neo4j_Username": neo4jUsername,
-    "Neo4j_Password": neo4jPassword,
+    // "Neo4j_URL": neo4jURL,
+    // "Neo4j_Username": neo4jUsername,
+    // "Neo4j_Password": neo4jPassword,
     "Model_Temperature": modelTemperature,
     "Model_Top_P": modelTopP,
   };
@@ -282,7 +282,7 @@ function IntialSetting(props) {
           </div>
 
           {/* Neo4j Configuration */}
-          <div className="form-group">
+          {/* <div className="form-group">
             <label htmlFor="neo4j-url">Neo4j URL</label>
             <input
               type="text"
@@ -321,7 +321,7 @@ function IntialSetting(props) {
                 {showPassword ? <FaEyeSlash /> : <FaEye />}
               </IconButton>
             </div>
-          </div>
+          </div> */}
 
           {/* Model Temperature and Top-P */}
           <div className="form-group">
main.py CHANGED
@@ -41,7 +41,7 @@ def initialize_components():
 
     from src.search.search_engine import SearchEngine
     from src.query_processing.query_processor import QueryProcessor
-    from src.rag.neo4j_graphrag import Neo4jGraphRAG
+    from src.rag.graph_rag import GraphRAG
     from src.evaluation.evaluator import Evaluator
     from src.reasoning.reasoner import Reasoner
     from src.crawl.crawler import CustomCrawler
@@ -53,7 +53,7 @@ def initialize_components():
     SESSION_STORE['search_engine'] = SearchEngine()
     SESSION_STORE['query_processor'] = QueryProcessor()
     SESSION_STORE['crawler'] = CustomCrawler(max_concurrent_requests=1000)
-    SESSION_STORE['graph_rag'] = Neo4jGraphRAG(num_workers=os.cpu_count() * 2)
+    SESSION_STORE['graph_rag'] = GraphRAG(num_workers=os.cpu_count() * 2)
     SESSION_STORE['evaluator'] = Evaluator()
     SESSION_STORE['reasoner'] = Reasoner()
     SESSION_STORE['model'] = manager.get_llm()
@@ -580,12 +580,10 @@ def action_sources(payload: Dict[str, Any]) -> Dict[str, Any]:
 
 # Define the route for graph action to display the graph
 @app.post("/action/graph")
-def action_graph(payload: Dict[str, Any]) -> Dict[str, Any]:
+def action_graph() -> Dict[str, Any]:
     state = SESSION_STORE
-
     try:
-        q = payload.get("query", "")
-        html_str = state['graph_rag'].display_graph(q)
+        html_str = state['graph_rag'].display_graph()
 
         return {"result": html_str}
     except Exception as e:
@@ -621,9 +619,9 @@ async def update_settings(data: Dict[str, Any]):
     multiple_api_keys = data.get("Model_API_Keys", "").strip()
     brave_api_key = data.get("Brave_Search_API_Key", "").strip()
     proxy_list = data.get("Proxy_List", "").strip()
-    neo4j_url = data.get("Neo4j_URL", "").strip()
-    neo4j_username = data.get("Neo4j_Username", "").strip()
-    neo4j_password = data.get("Neo4j_Password", "").strip()
+    # neo4j_url = data.get("Neo4j_URL", "").strip()
+    # neo4j_username = data.get("Neo4j_Username", "").strip()
+    # neo4j_password = data.get("Neo4j_Password", "").strip()
     model_temperature = str(data.get("Model_Temperature", 0.0))
     model_top_p = str(data.get("Model_Top_P", 1.0))
 
@@ -637,9 +635,9 @@ async def update_settings(data: Dict[str, Any]):
     env_updates.update(px)
 
     env_updates["BRAVE_API_KEY"] = brave_api_key
-    env_updates["NEO4J_URI"] = neo4j_url
-    env_updates["NEO4J_USER"] = neo4j_username
-    env_updates["NEO4J_PASSWORD"] = neo4j_password
+    # env_updates["NEO4J_URI"] = neo4j_url
+    # env_updates["NEO4J_USER"] = neo4j_username
+    # env_updates["NEO4J_PASSWORD"] = neo4j_password
    env_updates["MODEL_PROVIDER"] = prov_lower
    env_updates["MODEL_NAME"] = model_name
    env_updates["MODEL_TEMPERATURE"] = model_temperature
requirements.txt CHANGED
@@ -18,6 +18,7 @@ langchain_xai==0.1.1
 langgraph==0.2.62
 model2vec==0.3.3
 neo4j==5.26.0
+rustworkx==0.16.0
 openai==1.59.3
 protobuf==4.23.4
 PyPDF2==3.0.1
src/rag/graph_rag.py ADDED
@@ -0,0 +1,1433 @@
1
+ import os
2
+ import gc
3
+ import time
4
+ import asyncio
5
+ import torch
6
+ import uuid
7
+ import rustworkx as rx
8
+ import numpy as np
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from typing import List, Dict, Any
11
+ from pyvis.network import Network
12
+ from src.query_processing.late_chunking.late_chunker import LateChunker
13
+ from src.query_processing.query_processor import QueryProcessor
14
+ from src.reasoning.reasoner import Reasoner
15
+ from src.utils.api_key_manager import APIKeyManager
16
+ from src.search.search_engine import SearchEngine
17
+ from src.crawl.crawler import CustomCrawler #, Crawler
18
+ from sentence_transformers import SentenceTransformer
19
+ from bert_score.scorer import BERTScorer
20
+
21
+ class GraphRAG:
22
+ def __init__(self, num_workers: int = 1):
23
+ """Initialize graph and required components."""
24
+ # Dictionary to store multiple graphs
25
+ self.graphs = {}
26
+ self.current_graph_id = None
27
+
28
+ # Component initialization
29
+ self.num_workers = num_workers
30
+ self.search_engine = SearchEngine()
31
+ self.query_processor = QueryProcessor()
32
+ self.reasoner = Reasoner()
33
+ # self.crawler = Crawler(verbose=True)
34
+ self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
35
+ self.chunking = LateChunker()
36
+ self.llm = APIKeyManager().get_llm()
37
+
38
+ # Model initialization
39
+ self.model = SentenceTransformer(
40
+ "dunzhang/stella_en_400M_v5",
41
+ trust_remote_code=True,
42
+ device="cuda" if torch.cuda.is_available() else "cpu"
43
+ )
44
+ self.scorer = BERTScorer(
45
+ model_type="roberta-base",
46
+ lang="en",
47
+ rescale_with_baseline=True,
48
+ device="cuda" if torch.cuda.is_available() else "cpu"
49
+ )
50
+
51
+ # Counters and tracking
52
+ self.root_node_id = "QR"
53
+ self.node_counter = 0
54
+ self.sub_node_counter = 0
55
+ self.cross_connections = set()
56
+
57
+ # Thread pool
58
+ self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
59
+
60
+ # Event callback
61
+ self.on_event_callback = None
62
+
63
+ def set_on_event_callback(self, callback):
64
+ """Register a single callback to be triggered for various event types."""
65
+ self.on_event_callback = callback
66
+
67
+ async def emit_event(self, event_type: str, data: dict):
68
+ """Helper method to safely emit an event if a callback is registered."""
69
+ if self.on_event_callback:
70
+ if asyncio.iscoroutinefunction(self.on_event_callback):
71
+ return await self.on_event_callback(event_type, data)
72
+ else:
73
+ return self.on_event_callback(event_type, data)
74
+
75
+ def _get_current_graph_data(self):
76
+ if self.current_graph_id is None or self.current_graph_id not in self.graphs:
77
+ raise Exception("Error: No current graph selected")
78
+
79
+ return self.graphs[self.current_graph_id]
80
+
81
+ def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
82
+ """Add a node to the current graph."""
83
+ graph_data = self._get_current_graph_data()
84
+ graph = graph_data["graph"]
85
+ node_map = graph_data["node_map"]
86
+
87
+ # Generate embedding
88
+ embedding = self.model.encode(query).tolist()
89
+ node_data = {
90
+ "id": node_id,
91
+ "query": query,
92
+ "data": data,
93
+ "role": role,
94
+ "embedding": embedding,
95
+ "pagerank": 0,
96
+ "graph_id": self.current_graph_id
97
+ }
98
+ node_index = graph.add_node(node_data)
99
+ node_map[node_id] = node_index
100
+
101
+ print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")
102
+
103
+ def _has_path(self, source_idx: int, target_idx: int) -> bool:
104
+ """Helper method to check if there is a path from source to target in the current graph."""
105
+ graph_data = self._get_current_graph_data()
106
+ graph = graph_data["graph"]
107
+ visited = set()
108
+ stack = [source_idx]
109
+
110
+ while stack:
111
+ current = stack.pop()
112
+
113
+ if current == target_idx:
114
+ return True
115
+
116
+ if current in visited:
117
+ continue
118
+
119
+ visited.add(current)
120
+ for neighbor in graph.neighbors(current):
121
+ stack.append(neighbor)
122
+
123
+ return False
124
+
125
+ def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
126
+ """Add an edge between two nodes in a way that preserves a DAG structure."""
127
+ if self.current_graph_id is None:
128
+ raise Exception("Error: No current graph selected")
129
+
130
+ if node1 == node2:
131
+ print(f"Cannot add edge to the same node {node1}!")
132
+ return
133
+
134
+ graph_data = self._get_current_graph_data()
135
+ graph = graph_data["graph"]
136
+ node_map = graph_data["node_map"]
137
+
138
+ if node1 not in node_map or node2 not in node_map:
139
+ print(f"One or both nodes {node1}, {node2} do not exist in the current graph.")
140
+ return
141
+
142
+ idx1 = node_map[node1]
143
+ idx2 = node_map[node2]
144
+
145
+ # Check if adding this edge would create a cycle (i.e. if there is a path from node2 to node1)
146
+ if self._has_path(idx2, idx1):
147
+ print(f"An edge between {node1} -> {node2} already exists or would create a cycle!")
148
+ return
149
+
150
+ if relationship_type and weight:
151
+ edge_data = {"type": relationship_type, "weight": weight}
152
+ graph.add_edge(idx1, idx2, edge_data)
153
+ else:
154
+ raise ValueError("Error: Relationship type and weight must be provided")
155
+ print(f"Added edge between '{node1}' and '{node2}' in graph '{self.current_graph_id}' (type='{relationship_type}', weight={weight})")
156
+
157
+ def edge_exists(self, node1: str, node2: str) -> bool:
158
+ """Check if an edge exists between two nodes."""
159
+ graph_data = self._get_current_graph_data()
160
+ graph = graph_data["graph"]
161
+ node_map = graph_data["node_map"]
162
+
163
+ if node1 not in node_map or node2 not in node_map:
164
+ return False
165
+ idx1 = node_map[node1]
166
+ idx2 = node_map[node2]
167
+
168
+ for edge in graph.out_edges(idx1):
169
+ if edge[1] == idx2:
170
+ return True
171
+
172
+ return False
173
+
174
+ def graph_exists(self) -> bool:
175
+ """Check if a graph exists."""
176
+ return self.current_graph_id is not None and self.current_graph_id in self.graphs and len(self.graphs[self.current_graph_id]["node_map"]) > 0
177
+
178
+ def get_graphs(self) -> list:
179
+ """Get detailed information about all existing graphs and their nodes."""
180
+ result = []
181
+ for graph_id, data in self.graphs.items():
182
+ metadata = data["metadata"]
183
+ node_map = data["node_map"]
184
+ graph = data["graph"]
185
+ nodes_info = []
186
+
187
+ for node_id, idx in node_map.items():
188
+ node_data = graph.get_node_data(idx)
189
+ nodes_info.append({
190
+ "id": node_data.get("id"),
191
+ "query": node_data.get("query"),
192
+ "data": node_data.get("data"),
193
+ "role": node_data.get("role"),
194
+ "pagerank": node_data.get("pagerank")
195
+ })
196
+ edge_count = len(graph.edge_list())
197
+ result.append({
198
+ "graph_info": {
199
+ "graph_id": graph_id,
200
+ "created": metadata.get("created"),
201
+ "updated": metadata.get("updated"),
202
+ "node_count": len(node_map),
203
+ "edge_count": edge_count,
204
+ "nodes": nodes_info
205
+ }
206
+ })
207
+
208
+ result.sort(key=lambda x: x["graph_info"]["created"], reverse=True)
209
+ return result
210
+
211
+ def select_graph(self, graph_id: str) -> bool:
212
+ """Select a specific graph as the current working graph."""
213
+ if graph_id in self.graphs:
214
+ self.current_graph_id = graph_id
215
+ return True
216
+ return False
217
+
218
+ def create_new_graph(self) -> str:
219
+ """Create a new graph instance and its ID."""
220
+ graph_id = str(uuid.uuid4())
221
+ graph = rx.PyDiGraph()
222
+ node_map = {}
223
+ metadata = {
224
+ "id": graph_id,
225
+ "created": time.time(),
226
+ "updated": time.time()
227
+ }
228
+ self.graphs[graph_id] = {"graph": graph, "node_map": node_map, "metadata": metadata}
229
+ self.current_graph_id = graph_id
230
+
231
+ return graph_id
232
+
233
+ def load_graph(self, node_id: str) -> bool:
234
+ """Load an existing graph structure from memory based on a node ID."""
235
+
236
+ for gid, data in self.graphs.items():
237
+ if node_id in data["node_map"]:
238
+ self.current_graph_id = gid
239
+
240
+ for n_id in data["node_map"].keys():
241
+ if "SQ" in n_id:
242
+ num = int(''.join(filter(str.isdigit, n_id)) or 0)
243
+ self.node_counter = max(self.node_counter, num)
244
+ elif "SSQ" in n_id:
245
+ num = int(''.join(filter(str.isdigit, n_id)) or 0)
246
+ self.sub_node_counter = max(self.sub_node_counter, num)
247
+
248
+ self.node_counter += 1
249
+ self.sub_node_counter += 1
250
+ graph = data["graph"]
251
+ node_map = data["node_map"]
252
+
253
+ for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
254
+ if edge_data.get("type") == "logical":
255
+ source_id = graph.get_node_data(u).get("id")
256
+ target_id = graph.get_node_data(v).get("id")
257
+ connection = tuple(sorted([source_id, target_id]))
258
+ self.cross_connections.add(connection)
259
+
260
+ print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
261
+ return True
262
+
263
+ print(f"Graph with node_id {node_id} not found.")
264
+
265
+ return False
266
+
267
+ async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
268
+ """Modify an existing graph structure by integrating a new query."""
269
+ graph_data = self._get_current_graph_data()
270
+ graph = graph_data["graph"]
271
+ node_map = graph_data["node_map"]
272
+
273
+ async def add_as_sibling(node_id: str, query: str):
274
+ if node_id not in node_map:
275
+ raise ValueError(f"Node {node_id} not found")
276
+
277
+ idx = node_map[node_id]
278
+ in_edges = graph.in_edges(idx)
279
+
280
+ if not in_edges:
281
+ raise ValueError(f"No parent found for node {node_id}")
282
+
283
+ parent_idx = in_edges[0][0]
284
+ parent_data = graph.get_node_data(parent_idx)
285
+ parent_id = parent_data.get("id")
286
+
287
+ if "SQ" in node_id:
288
+ self.node_counter += 1
289
+ new_node_id = f"SQ{self.node_counter}"
290
+ else:
291
+ self.sub_node_counter += 1
292
+ new_node_id = f"SSQ{self.sub_node_counter}"
293
+
294
+ self.add_node(new_node_id, query, role="independent")
295
+ self.add_edge(parent_id, new_node_id, relationship_type=in_edges[0][2].get("type"))
296
+
297
+ return new_node_id
298
+
299
+ async def add_as_child(node_id: str, query: str):
300
+ if "SQ" in node_id:
301
+ self.sub_node_counter += 1
302
+ new_node_id = f"SSQ{self.sub_node_counter}"
303
+ else:
304
+ self.node_counter += 1
305
+ new_node_id = f"SQ{self.node_counter}"
306
+
307
+ self.add_node(new_node_id, query, role="dependent")
308
+ self.add_edge(node_id, new_node_id, relationship_type="logical")
309
+
310
+ return new_node_id
311
+
312
+ def collect_graph_context() -> list:
313
+ """Collect context from existing graph nodes."""
314
+ graph_data = self._get_current_graph_data()
315
+ graph = graph_data["graph"]
316
+ node_map = graph_data["node_map"]
317
+ nodes = []
318
+
319
+ for n_id, idx in node_map.items():
320
+ if n_id == self.root_node_id:
321
+ continue
322
+ node_data = graph.get_node_data(idx)
323
+ nodes.append({
324
+ "id": node_data.get("id"),
325
+ "query": node_data.get("query"),
326
+ "role": node_data.get("role")
327
+ })
328
+
329
+ nodes.sort(key=lambda x: (0 if x["id"].startswith("SQ") else (1 if x["id"].startswith("SSQ") else 2), x["id"]))
330
+ level_queries = {}
331
+ current_sq = None
332
+
333
+ for node in nodes:
334
+ node_id = node["id"]
335
+ if node_id.startswith("SQ"):
336
+ current_sq = node_id
337
+
338
+ if current_sq not in level_queries:
339
+ level_queries[current_sq] = {
340
+ "originalquery": node["query"],
341
+ "subqueries": []
342
+ }
343
+ level_queries[current_sq]["subqueries"].append({
344
+ "subquery": node["query"],
345
+ "role": node["role"],
346
+ "dependson": []
347
+ })
348
+
349
+ elif node_id.startswith("SSQ") and current_sq:
350
+ level_queries[current_sq]["subqueries"].append({
351
+ "subquery": node["query"],
352
+ "role": node["role"],
353
+ "dependson": []
354
+ })
355
+
356
+ return list(level_queries.values())
357
+
358
+ if similar_node_id not in node_map:
359
+ raise Exception(f"Node {similar_node_id} not found")
360
+
361
+ similar_node_data = graph.get_node_data(node_map[similar_node_id])
362
+ has_parent = len(graph.in_edges(node_map[similar_node_id])) > 0
363
+
364
+ context = collect_graph_context()
365
+ if similar_node_data.get("role") == "independent":
366
+ if has_parent:
367
+ new_node_id = await add_as_sibling(similar_node_id, new_query)
368
+ else:
369
+ new_node_id = await add_as_child(similar_node_id, new_query)
370
+ else:
371
+ new_node_id = await add_as_child(similar_node_id, new_query)
372
+
373
+ await self.build_graph(
374
+ query=new_query,
375
+ parent_node_id=new_node_id,
376
+ depth=1 if "SQ" in new_node_id else 2,
377
+ context=context,
378
+ session_id=session_id
379
+ )
380
+
381
+ async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
382
+ depth: int = 0, threshold: float = 0.8, recurse: bool = True,
383
+ context: list = None, session_id: str = None, max_tokens_allowed: int = 128000):
384
+ """Build a new graph structure in memory."""
385
+ async def process_node(node_id: str, sub_query: str, session_id: str, future: asyncio.Future, depth=depth, max_tokens_allowed=max_tokens_allowed):
386
+ try:
387
+ optimized_query = await self.search_engine.generate_optimized_query(sub_query)
388
+ results = await self.search_engine.search(
389
+ query=optimized_query,
390
+ num_results=10,
391
+ exclude_filetypes=["pdf"]
392
+ )
393
+ await self.emit_event("search_results_fetched", {
394
+ "node_id": node_id,
395
+ "sub_query": sub_query,
396
+ "optimized_query": optimized_query,
397
+ "search_results": results
398
+ })
399
+ filtered_urls = await self.search_engine.filter_urls(
400
+ sub_query,
401
+ "extensive research dynamic structure",
402
+ results
403
+ )
404
+ await self.emit_event("search_results_filtered", {
405
+ "node_id": node_id,
406
+ "sub_query": sub_query,
407
+ "filtered_urls": filtered_urls
408
+ })
409
+ urls = [result.get('link', 'No URL') for result in filtered_urls]
410
+ search_contents = await self.custom_crawler.fetch_page_contents(
411
+ urls,
412
+ sub_query,
413
+ session_id=session_id,
414
+ max_attempts=1,
415
+ timeout=30
416
+ )
417
+ await self.emit_event("search_contents_fetched", {
418
+ "node_id": node_id,
419
+ "sub_query": sub_query,
420
+ "contents": search_contents
421
+ })
422
+
423
+ contents = ""
424
+ for k, content in enumerate(search_contents, 1):
425
+ if isinstance(content, Exception):
426
+ print(f"Error fetching content: {content}")
427
+ elif content:
428
+ contents += f"Document {k}:\n{content}\n\n"
429
+
430
+ if contents.strip():
431
+ if depth == 0:
432
+ await self.emit_event("sub_query_processed", {
433
+ "node_id": node_id,
434
+ "sub_query": sub_query,
435
+ "contents": contents
436
+ })
437
+
438
+ token_count = self.llm.get_num_tokens(contents)
439
+ if token_count > max_tokens_allowed:
440
+ contents = await self.chunking.chunker(
441
+ text=contents,
442
+ query=sub_query,
443
+ max_tokens=max_tokens_allowed
444
+ )
445
+ print(f"Number of tokens in the answer: {token_count}")
446
+ print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
447
+ else:
448
+ if depth == 0:
449
+ await self.emit_event("sub_query_failed", {
450
+ "node_id": node_id,
451
+ "sub_query": sub_query,
452
+ "contents": contents
453
+ })
454
+
455
+ graph_data = self._get_current_graph_data()
456
+ graph = graph_data["graph"]
457
+ node_map = graph_data["node_map"]
458
+
459
+ if node_id in node_map:
460
+ idx = node_map[node_id]
461
+ node_data = graph.get_node_data(idx)
462
+ node_data["data"] = contents
463
+ future.set_result(contents)
464
+ except Exception as e:
465
+ print(f"Error processing node {node_id}: {str(e)}")
466
+ future.set_exception(e)
467
+ raise
468
+
469
+ async def process_dependent_node(node_id: str, sub_query: str, depth, dep_futures: list, future):
470
+ try:
471
+ dep_data = [await f for f in dep_futures]
472
+ modified_query = await self.query_processor.modify_query(
473
+ sub_query,
474
+ dep_data
475
+ )
476
+ loop = asyncio.get_running_loop()
477
+ embedding = await loop.run_in_executor(
478
+ self.executor,
479
+ self.model.encode,
480
+ modified_query
481
+ )
482
+ graph_data = self._get_current_graph_data()
483
+ graph = graph_data["graph"]
484
+ node_map = graph_data["node_map"]
485
+
486
+ if node_id in node_map:
487
+ idx = node_map[node_id]
488
+ node_data = graph.get_node_data(idx)
489
+ node_data["query"] = modified_query
490
+ node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
491
+ try:
492
+ if not future.done():
493
+ await process_node(node_id, modified_query, session_id, future, depth, max_tokens_allowed)
494
+ except Exception as e:
495
+ if not future.done():
496
+ future.set_exception(e)
497
+ raise
498
+ except Exception as e:
499
+ print(f"Error processing dependent node {node_id}: {str(e)}")
500
+ if not future.done():
501
+ future.set_exception(e)
502
+ raise
503
+
504
+ def create_cross_connections():
505
+ try:
506
+ relationships = self.get_node_relationships(relationship_type='logical')
507
+
508
+ for current_node_id, edges in relationships.items():
509
+ graph_data = self._get_current_graph_data()
510
+ graph = graph_data["graph"]
511
+ node_map = graph_data["node_map"]
512
+
513
+ if current_node_id not in node_map:
514
+ continue
515
+
516
+ idx = node_map[current_node_id]
517
+ node_data = graph.get_node_data(idx)
518
+ node_role = (node_data.get("role") or "").lower()
519
+
520
+ if node_role == 'dependent':
521
+ for source_id, target_id, edge_data in edges['in_edges']:
522
+ if not source_id or source_id == self.root_node_id:
523
+ continue
524
+
525
+ connection = tuple(sorted([current_node_id, source_id]))
526
+ if connection not in self.cross_connections:
527
+ if not self.edge_exists(source_id, current_node_id):
528
+ print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
529
+ self.add_edge(source_id, current_node_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
530
+ self.cross_connections.add(connection)
531
+
532
+ for source_id, target_id, edge_data in edges['out_edges']:
533
+ if not target_id or target_id == self.root_node_id:
534
+ continue
535
+
536
+ connection = tuple(sorted([current_node_id, target_id]))
537
+ if connection not in self.cross_connections:
538
+ if not self.edge_exists(current_node_id, target_id):
539
+ print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
540
+ self.add_edge(current_node_id, target_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
541
+ self.cross_connections.add(connection)
542
+ except Exception as e:
543
+ print(f"Error creating cross connections: {str(e)}")
544
+ raise
545
+
546
+ if depth > 1:
547
+ return
548
+
549
+ if context is None:
550
+ context = []
551
+ node_data_futures = {}
552
+
553
+ if parent_node_id is None:
554
+ self.add_node(self.root_node_id, query, data)
555
+ parent_node_id = self.root_node_id
556
+
557
+ intent = await self.query_processor.get_query_intent(query)
558
+
559
+ if depth == 0:
560
+ response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent)
561
+ else:
562
+ response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent, context)
563
+
564
+ if response_data:
565
+ context.append(response_data)
566
+
567
+ if len(sub_queries) > 1 and sub_queries[0] != query:
568
+ sub_query_ids = []
569
+ pre_req_nodes = {}
570
+
571
+ for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
572
+ if depth == 0:
573
+ await self.emit_event("sub_query_created", {
574
+ "depth": depth,
575
+ "sub_query": sub_query,
576
+ "role": role,
577
+ "dependency": dependency,
578
+ "parent_node_id": parent_node_id,
579
+ })
580
+
581
+ if depth == 0:
582
+ self.node_counter += 1
583
+ sub_node_id = f"SQ{self.node_counter}"
584
+ else:
585
+ self.sub_node_counter += 1
586
+ sub_node_id = f"SSQ{self.sub_node_counter}"
587
+
588
+ sub_query_ids.append(sub_node_id)
589
+ self.add_node(sub_node_id, sub_query, role=role)
590
+ future = asyncio.Future()
591
+ node_data_futures[sub_node_id] = future
592
+
593
+ if role.lower() in ['pre-requisite', 'prerequisite']:
594
+ pre_req_nodes[idx] = sub_node_id
595
+
596
+ if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
597
+ self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
598
+
599
+ elif role.lower() == 'dependent':
600
+ if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
601
+ print(f"Dependency: {dependency}")
602
+ prev_deps, current_deps = dependency
603
+
604
+ if context and prev_deps not in [None, []]:
605
+ for dep_idx in prev_deps:
606
+ if dep_idx is not None:
607
+ matching_nodes = self.find_nodes_by_properties(query=dep_idx)
608
+
609
+ if matching_nodes not in [None, []]:
610
+ dep_node_id = matching_nodes[0].get('node_id')
611
+ score = matching_nodes[0].get('score', 0)
612
+
613
+ if score >= 0.9:
614
+ self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
615
+
616
+ if current_deps not in [None, []]:
617
+ for dep_idx in current_deps:
618
+ if dep_idx < len(sub_queries):
619
+ dep_node_id = sub_query_ids[dep_idx]
620
+ self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
621
+ else:
622
+ raise ValueError(f"Invalid dependency index: {dep_idx}")
623
+
624
+ elif len(dependency) > 0:
625
+ for dep_idx in dependency:
626
+ if dep_idx < len(sub_queries):
627
+ dep_node_id = sub_query_ids[dep_idx]
628
+ self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
629
+ else:
630
+ raise ValueError(f"Invalid dependency index: {dep_idx}")
631
+ else:
632
+ raise ValueError(f"Invalid dependency: {dependency}")
633
+ else:
634
+ raise ValueError(f"Unexpected role: {role}")
635
+
636
+ tasks = []
637
+ for idx in range(len(sub_queries)):
638
+ node_id = sub_query_ids[idx]
639
+ future = node_data_futures[node_id]
640
+
641
+ if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
642
+ tasks.append(process_node(node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed))
643
+
644
+ for idx in range(len(sub_queries)):
645
+ node_id = sub_query_ids[idx]
646
+ future = node_data_futures[node_id]
647
+
648
+ if roles[idx].lower() == 'dependent':
649
+ dep_futures = []
650
+
651
+ if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
652
+ prev_deps, current_deps = dependencies[idx]
653
+
654
+ if context and prev_deps not in [None, []]:
655
+ for context_idx, context_data in enumerate(context):
656
+ if isinstance(prev_deps, list) and context_idx < len(prev_deps):
657
+ context_dep = prev_deps[context_idx]
658
+
659
+ if context_dep is not None and isinstance(context_data, dict) and 'subqueries' in context_data:
660
+
661
+ if context_dep < len(context_data['subqueries']):
662
+ dep_query = context_data['subqueries'][context_dep]['subquery']
663
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
664
+
665
+ if matching_nodes not in [None, []]:
666
+ dep_node_id = matching_nodes[0].get('node_id', None)
667
+ score = float(matching_nodes[0].get('score', 0))
668
+
669
+ if score == 1.0 and dep_node_id in node_data_futures:
670
+ dep_futures.append(node_data_futures[dep_node_id])
671
+
672
+ elif isinstance(prev_deps, int):
673
+ if prev_deps < len(context_data['subqueries']):
674
+ dep_query = context_data['subqueries'][prev_deps]['subquery']
675
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
676
+
677
+ if matching_nodes not in [None, []]:
678
+ dep_node_id = matching_nodes[0].get('node_id', None)
679
+ score = matching_nodes[0].get('score', 0)
680
+
681
+ if score == 1.0 and dep_node_id in node_data_futures:
682
+ dep_futures.append(node_data_futures[dep_node_id])
683
+
684
+ if current_deps not in [None, []]:
685
+ current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
686
+
687
+ for dep_idx in current_deps_list:
688
+ if dep_idx < len(sub_queries):
689
+ dep_node_id = sub_query_ids[dep_idx]
690
+ if dep_node_id in node_data_futures:
691
+ dep_futures.append(node_data_futures[dep_node_id])
692
+
693
+ tasks.append(process_dependent_node(node_id, sub_queries[idx], depth, dep_futures, future))
694
+
695
+ if depth == 0:
696
+ await self.emit_event("search_process_started", {
697
+ "depth": depth,
698
+ "sub_queries": sub_queries,
699
+ "roles": roles
700
+ })
701
+
702
+ await asyncio.gather(*tasks)
703
+
704
+ if recurse:
705
+ recursion_tasks = []
706
+
707
+ for idx, sub_query in enumerate(sub_queries):
708
+ try:
709
+ sub_node_id = sub_query_ids[idx]
710
+ recursion_tasks.append(
711
+ self.build_graph(
712
+ query=sub_query,
713
+ parent_node_id=sub_node_id,
714
+ depth=depth + 1,
715
+ threshold=threshold,
716
+ recurse=recurse,
717
+ context=context,
718
+ session_id=session_id
719
+ )
720
+ )
721
+ except Exception as e:
722
+ print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
723
+ continue
724
+
725
+ if recursion_tasks:
726
+ try:
727
+ await asyncio.gather(*recursion_tasks)
728
+ except Exception as e:
729
+ raise Exception(f"Error during recursive processing: {e}")
730
+
731
+ if depth == 0:
732
+ print("Graph building complete, processing final tasks...")
733
+ await self.emit_event("search_process_completed", {
734
+ "depth": depth,
735
+ "sub_queries": sub_queries,
736
+ "roles": roles
737
+ })
738
+
739
+ create_cross_connections()
740
+ print("All cross-connections have been created!")
741
+ print(f"Adding similarity edges with threshold {threshold}")
742
+
743
+ graph_data = self._get_current_graph_data()
744
+ node_map = graph_data["node_map"]
745
+ all_node_ids = list(node_map.keys())
746
+
747
+ for i, node1 in enumerate(all_node_ids):
748
+ for node2 in all_node_ids[i+1:]:
749
+ if not self.edge_exists(node1, node2):
750
+ self.add_edge_based_on_similarity_and_relevance(node1, node2, query, threshold)
751
+
752
+ print("All similarity edges have been added!")
753
+
754
+ async def process_graph(
755
+ self,
756
+ query: str,
757
+ data: str = None,
758
+ similarity_threshold: float = 0.8,
759
+ relevance_threshold: float = 0.7,
760
+ sub_sub_queries: bool = True,
761
+ session_id: str = None,
762
+ max_tokens_allowed: int = 128000
763
+ ):
764
+ """Process a query and manage graph creation/modification."""
765
+ def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
766
+ if self.current_graph_id is None:
767
+ raise Exception("Error: No current graph ID. Cannot check query similarity.")
768
+
769
+ graph_data = self._get_current_graph_data()
770
+ graph = graph_data["graph"]
771
+ node_map = graph_data["node_map"]
772
+ similarities = []
773
+
774
+ if not node_map:
775
+ return {"should_create_new": True}
776
+
777
+ for node_id, idx in node_map.items():
778
+ node_data = graph.get_node_data(idx)
779
+
780
+ if not node_data.get("query"):
781
+ continue
782
+
783
+ similarity = self.calculate_query_similarity(new_query, node_data.get("query"))
784
+ if similarity >= similarity_threshold:
785
+ similarities.append({
786
+ "node_id": node_id,
787
+ "query": node_data.get("query"),
788
+ "score": similarity,
789
+ "role": node_data.get("role")
790
+ })
791
+
792
+ if not similarities:
793
+ print(f"No similar queries found above threshold {similarity_threshold}")
794
+ return {"should_create_new": True}
795
+
796
+ best_match = max(similarities, key=lambda x: x["score"])
797
+
798
+ rel_type = "root"
799
+ if "SSQ" in best_match["node_id"]:
800
+ rel_type = "sub-sub"
801
+
802
+ elif "SQ" in best_match["node_id"]:
803
+ rel_type = "sub"
804
+
805
+ return {
806
+ "most_similar_query": best_match["query"],
807
+ "similarity_score": best_match["score"],
808
+ "relationship_type": rel_type,
809
+ "node_id": best_match["node_id"],
810
+ "should_create_new": best_match["score"] < similarity_threshold
811
+ }
812
+ try:
813
+ graphs = self.get_graphs()
814
+
815
+ if not graphs:
816
+ print("No existing graphs found. Creating new graph.")
817
+ self.create_new_graph()
818
+ await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
819
+ await self.build_graph(
820
+ query=query,
821
+ data=data,
822
+ threshold=relevance_threshold,
823
+ recurse=sub_sub_queries,
824
+ session_id=session_id,
825
+ max_tokens_allowed=max_tokens_allowed
826
+ )
827
+ gc.collect()
828
+ self.prune_edges()
829
+ self.update_pagerank()
830
+ self.verify_graph_integrity()
831
+ self.verify_graph_consistency()
832
+ return
833
+
834
+ max_similarity = 0
835
+ most_similar_graph = None
836
+ consolidated_graphs = {}
837
+
838
+ for graph_obj in graphs:
839
+ graph_info = graph_obj.get("graph_info")
840
+ if not graph_info:
841
+ continue
842
+
843
+ graph_id = graph_info.get("graph_id")
844
+
845
+ if not graph_id:
846
+ continue
847
+
848
+ if graph_id not in consolidated_graphs:
849
+ consolidated_graphs[graph_id] = {
850
+ "graph_id": graph_id,
851
+ "nodes": []
852
+ }
853
+
854
+ if graph_info.get("nodes"):
855
+ consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])
856
+
857
+ for graph_id, graph_data in consolidated_graphs.items():
858
+ nodes = graph_data["nodes"]
859
+
860
+ for node in nodes:
861
+ if node.get("query"):
862
+ similarity = self.calculate_query_similarity(query, node["query"])
863
+
864
+ if node.get("id", "").startswith("SQ"):
865
+ asyncio.create_task(self.emit_event("retrieved_sub_query", {
866
+ "sub_query": node["query"]
867
+ }))
868
+
869
+ if similarity > max_similarity:
870
+ max_similarity = similarity
871
+ most_similar_graph = graph_id
872
+
873
+ if max_similarity >= similarity_threshold:
874
+ print(f"Found similar query with score {round(max_similarity, 2)}")
875
+ self.current_graph_id = most_similar_graph
876
+
877
+ if round(max_similarity, 2) == 1.0:
878
+ print("Loading and using existing graph")
879
+ await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
880
+ success = self.load_graph(self.root_node_id)
881
+
882
+ if not success:
883
+ raise Exception("Failed to load existing graph")
884
+
885
+ else:
886
+ print("Checking for node-level similarity...")
887
+ similarity_info = check_query_similarity(
888
+ query,
889
+ similarity_threshold
890
+ )
891
+
892
+ if similarity_info["relationship_type"] in ["sub", "sub-sub"]:
893
+ print(f"Most Similar Query: {similarity_info['most_similar_query']}")
894
+ print("Modifying existing graph structure")
895
+ await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
896
+ await self.modify_graph(
897
+ query,
898
+ similarity_info["node_id"],
899
+ session_id=session_id
900
+ )
901
+ gc.collect()
902
+ self.prune_edges()
903
+ self.update_pagerank()
904
+ self.verify_graph_integrity()
905
+ self.verify_graph_consistency()
906
+
907
+ else:
908
+ print(f"Creating new graph for query: {query}")
909
+ self.create_new_graph()
910
+ await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
911
+ await self.build_graph(
912
+ query=query,
913
+ data=data,
914
+ threshold=relevance_threshold,
915
+ recurse=sub_sub_queries,
916
+ session_id=session_id,
917
+ max_tokens_allowed=max_tokens_allowed
918
+ )
919
+ gc.collect()
920
+ self.prune_edges()
921
+ self.update_pagerank()
922
+ self.verify_graph_integrity()
923
+ self.verify_graph_consistency()
924
+ except Exception as e:
925
+ print(f"Error in process_graph: {str(e)}")
926
+ raise
927
+
928
+ def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
929
+ """Add edges based on node similarity and relevance."""
930
+ graph_data = self._get_current_graph_data()
931
+ graph = graph_data["graph"]
932
+ node_map = graph_data["node_map"]
933
+
934
+ if node1_id not in node_map or node2_id not in node_map:
935
+ return
936
+
937
+ idx1 = node_map[node1_id]
938
+ idx2 = node_map[node2_id]
939
+ node1_data = graph.get_node_data(idx1)
940
+ node2_data = graph.get_node_data(idx2)
941
+
942
+ if not all([node1_data.get("embedding"), node2_data.get("embedding"), node1_data.get("data"), node2_data.get("data")]):
943
+ return
944
+
945
+ similarity = self.cosine_similarity(node1_data["embedding"], node2_data["embedding"])
946
+ query_relevance1 = self.calculate_relevance(query, node1_data["data"])
947
+ query_relevance2 = self.calculate_relevance(query, node2_data["data"])
948
+ node_relevance = self.calculate_relevance(node1_data["data"], node2_data["data"])
949
+ weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4
950
+
951
+ if weight >= threshold:
952
+ self.add_edge(node1_id, node2_id, weight=weight, relationship_type='similarity_and_relevance')
953
+ print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")
954
+
955
+ def calculate_relevance(self, data1: str, data2: str) -> float:
956
+ """Calculate relevance between two data strings."""
957
+ try:
958
+ if not data1 or not data2:
959
+ return 0.0
960
+
961
+ P, R, F1 = self.scorer.score([data1], [data2])
962
+ return F1.mean().item()
963
+ except Exception as e:
964
+ print(f"Error calculating relevance: {str(e)}")
965
+ return 0.0
966
+
967
+ def calculate_query_similarity(self, query1: str, query2: str) -> float:
968
+ """Calculate similarity between two queries."""
969
+ try:
970
+ embedding1 = self.model.encode(query1).tolist()
971
+ embedding2 = self.model.encode(query2).tolist()
972
+ return self.cosine_similarity(embedding1, embedding2)
973
+ except Exception as e:
974
+ print(f"Error calculating query similarity: {str(e)}")
975
+ return 0.0
976
+
977
+ def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
978
+ """Get similarities and relevance between nodes."""
979
+ try:
980
+ graph_data = self._get_current_graph_data()
981
+ graph = graph_data["graph"]
982
+ node_map = graph_data["node_map"]
983
+ nodes = []
984
+
985
+ for node_id, idx in node_map.items():
986
+ node_data = graph.get_node_data(idx)
987
+ nodes.append({
988
+ "id": node_data.get("id"),
989
+ "embedding": node_data.get("embedding"),
990
+ "data": node_data.get("data")
991
+ })
992
+
993
+ similarities = []
994
+ for i, node1 in enumerate(nodes):
995
+ for node2 in nodes[i + 1:]:
996
+ similarity = self.cosine_similarity(node1["embedding"], node2["embedding"])
997
+ relevance = self.calculate_relevance(node1["data"], node2["data"])
998
+ weight = (similarity + relevance) / 2
999
+
1000
+ if weight >= threshold:
1001
+ similarities.append({
1002
+ 'node1': node1["id"],
1003
+ 'node2': node2["id"],
1004
+ 'similarity': similarity,
1005
+ 'relevance': relevance,
1006
+ 'weight': weight
1007
+ })
1008
+
1009
+ return similarities
1010
+ except Exception as e:
1011
+ print(f"Error getting similarities and relevance: {str(e)}")
1012
+ return []
1013
+
1014
+ def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
1015
+ """Get relationships between nodes with filtering options."""
1016
+ graph_data = self._get_current_graph_data()
1017
+ graph = graph_data["graph"]
1018
+ node_map = graph_data["node_map"]
1019
+ relationships = {}
1020
+
1021
+ for n_id, idx in node_map.items():
1022
+ if n_id == self.root_node_id:
1023
+ continue
1024
+
1025
+ node_data = graph.get_node_data(idx)
1026
+
1027
+ if node_id and n_id != node_id:
1028
+ continue
1029
+
1030
+ if role and node_data.get("role") != role:
1031
+ continue
1032
+
1033
+ in_edges = []
1034
+ for u, v, edge_data in graph.in_edges(idx):
1035
+ source_id = graph.get_node_data(u).get("id")
1036
+ in_edges.append((source_id, n_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
1037
+
1038
+ out_edges = []
1039
+ for u, v, edge_data in graph.out_edges(idx):
1040
+ target_id = graph.get_node_data(v).get("id")
1041
+ out_edges.append((n_id, target_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
1042
+
1043
+ relationships[n_id] = {"in_edges": in_edges, "out_edges": out_edges}
1044
+
1045
+ return relationships
1046
+
1047
+ def find_nodes_by_properties(self, query: str = None, embedding: list = None,
1048
+ node_data: dict = None, similarity_threshold: float = 0.8) -> list:
1049
+ """Find nodes based on properties."""
1050
+ try:
1051
+ graph_data = self._get_current_graph_data()
1052
+ graph = graph_data["graph"]
1053
+ node_map = graph_data["node_map"]
1054
+ matching_nodes = []
1055
+
1056
+ for n_id, idx in node_map.items():
1057
+ data = graph.get_node_data(idx)
1058
+ match_score = 0
1059
+ matches = 0
1060
+
1061
+ if query and query.lower() in data.get("query", "").lower():
1062
+ match_score += 1
1063
+ matches += 1
1064
+
1065
+ if embedding and "embedding" in data:
1066
+ sim = self.cosine_similarity(embedding, data["embedding"])
1067
+
1068
+ if sim >= similarity_threshold:
1069
+ match_score += sim
1070
+ matches += 1
1071
+
1072
+ if node_data:
1073
+ data_matches = sum(1 for k, v in node_data.items() if k in data and data[k] == v)
1074
+
1075
+ if data_matches > 0:
1076
+ match_score += data_matches / len(node_data)
1077
+ matches += 1
1078
+
1079
+ if matches > 0:
1080
+ matching_nodes.append({
1081
+ "node_id": n_id,
1082
+ "score": match_score / matches,
1083
+ "data": data
1084
+ })
1085
+
1086
+ matching_nodes.sort(key=lambda x: x["score"], reverse=True)
1087
+
1088
+ return matching_nodes
1089
+ except Exception as e:
1090
+ print(f"Error finding nodes by properties: {str(e)}")
1091
+ raise
1092
+
1093
+ def query_graph(self, query: str) -> str:
1094
+ """Query the graph for a specific query, collecting data from the entire relevant subgraph."""
1095
+ graph_data = self._get_current_graph_data()
1096
+ graph = graph_data["graph"]
1097
+ node_map = graph_data["node_map"]
1098
+ target_node_id = None
1099
+
1100
+ for n_id, idx in node_map.items():
1101
+ if graph.get_node_data(idx).get("query") == query:
1102
+ target_node_id = n_id
1103
+ break
1104
+
1105
+ if not target_node_id:
1106
+ raise ValueError(f"Query node not found for: {query}")
1107
+
1108
+ datas = []
1109
+ start_idx = node_map[target_node_id]
1110
+ visited = set()
1111
+ stack = [start_idx]
1112
+
1113
+ while stack:
1114
+ current = stack.pop()
1115
+
1116
+ if current in visited:
1117
+ continue
1118
+ visited.add(current)
1119
+ current_data = graph.get_node_data(current)
1120
+
1121
+ if current_data.get("data") and current_data.get("data").strip():
1122
+ datas.append(current_data.get("data").strip())
1123
+
1124
+ for neighbor in graph.neighbors(current):
1125
+ if neighbor not in visited:
1126
+ stack.append(neighbor)
1127
+
1128
+ if not datas:
1129
+ print(f"No data found for: {query}")
1130
+ return ""
1131
+
1132
+ return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)])
1133
+
1134
+ def prune_edges(self, max_edges: int = 1000):
1135
+ """Prune excess edges while preserving node data."""
1136
+ print(f"Pruning edges to maximum {max_edges} edges...")
1137
+ graph_data = self._get_current_graph_data()
1138
+ graph = graph_data["graph"]
1139
+ all_edges = list(graph.edge_list())
1140
+ current_edges = len(all_edges)
1141
+
1142
+ if current_edges > max_edges:
1143
+ sorted_edges = sorted(all_edges, key=lambda x: x[2].get("weight", 1.0), reverse=True)
1144
+ edges_to_keep = set()
1145
+
1146
+ for edge in sorted_edges[:max_edges]:
1147
+ edges_to_keep.add((edge[0], edge[1]))
1148
+
1149
+ edges_to_remove = []
1150
+ for edge in all_edges:
1151
+ if (edge[0], edge[1]) not in edges_to_keep:
1152
+ edges_to_remove.append((edge[0], edge[1]))
1153
+
1154
+ for u, v in edges_to_remove:
1155
+ try:
1156
+ graph.remove_edge(u, v)
1157
+ except Exception as e:
1158
+ print(f"Error removing edge from {u} to {v}: {e}")
1159
+
1160
+ print(f"Pruned edges. Kept top {max_edges} edges by weight.")
1161
+
1162
+ print("No pruning required. Current edge count is within limits.")
1163
+
1164
+ def update_pagerank(self):
1165
+ """Update PageRank values using Rustworkx's pagerank algorithm."""
1166
+ if not self.current_graph_id:
1167
+ print("No current graph selected. Cannot compute PageRank.")
1168
+ return
1169
+
1170
+ graph_data = self._get_current_graph_data()
1171
+ graph = graph_data["graph"]
1172
+
1173
+ try:
1174
+ pr = rx.pagerank(graph, weight_fn=lambda e: e.get("weight", 1.0))
1175
+ node_map = graph_data["node_map"]
1176
+
1177
+ for n_id, idx in node_map.items():
1178
+ node_data = graph.get_node_data(idx)
1179
+ node_data["pagerank"] = pr[idx]
1180
+
1181
+ print("PageRank updated successfully")
1182
+ except Exception as e:
1183
+ print(f"Error updating PageRank: {str(e)}")
1184
+ raise
1185
+
1186
+ def display_graph(self):
1187
+ """Display the graph using PyVis."""
1188
+ graph_data = self._get_current_graph_data()
1189
+ graph = graph_data["graph"]
1190
+ node_map = graph_data["node_map"]
1191
+ net = Network(height="600px", width="100%", directed=True, bgcolor="#222222", font_color="white")
1192
+ net.options = {"physics": {"enabled": False}}
1193
+ all_nodes = set()
1194
+ all_edges = []
1195
+
1196
+ for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
1197
+ source_data = graph.get_node_data(u)
1198
+ target_data = graph.get_node_data(v)
1199
+ source_id = source_data.get("id")
1200
+ target_id = target_data.get("id")
1201
+ source_tooltip = f"Query: {source_data.get('query', 'N/A')}"
1202
+ target_tooltip = f"Query: {target_data.get('query', 'N/A')}"
1203
+
1204
+ if source_id not in all_nodes:
1205
+ net.add_node(source_id, label=source_id, title=source_tooltip, size=20, color="#00cc66")
1206
+ all_nodes.add(source_id)
1207
+
1208
+ if target_id not in all_nodes:
1209
+ net.add_node(target_id, label=target_id, title=target_tooltip, size=20, color="#00cc66")
1210
+ all_nodes.add(target_id)
1211
+
1212
+ edge_type = edge_data.get("type", "N/A")
1213
+ edge_weight = edge_data.get("weight", "N/A")
1214
+ edge_tooltip = f"Weight: {edge_weight}"
1215
+ all_edges.append({
1216
+ "from": source_id,
1217
+ "to": target_id,
1218
+ "label": edge_type,
1219
+ "title": edge_tooltip
1220
+ })
1221
+
1222
+ for edge in all_edges:
1223
+ net.add_edge(edge["from"], edge["to"], title=edge["title"], color="#cccccc")
1224
+
1225
+ net.options["layout"] = {"improvedLayout": True}
1226
+ net.options["interaction"] = {"dragNodes": True}
1227
+
1228
+ net.save_graph("temp_graph.html")
1229
+
1230
+ with open("temp_graph.html", "r", encoding="utf-8") as f:
1231
+ html_str = f.read()
1232
+ os.remove("temp_graph.html")
1233
+ return html_str
1234
+
1235
+ def verify_graph_integrity(self):
1236
+ """Verify and fix graph integrity issues."""
1237
+ graph_data = self._get_current_graph_data()
1238
+ graph = graph_data["graph"]
1239
+ node_map = graph_data["node_map"]
1240
+ orphaned = []
1241
+
1242
+        for n_id, idx in node_map.items():
+            if not graph.in_edges(idx) and not graph.out_edges(idx):
+                orphaned.append(n_id)
+
+        if orphaned:
+            print(f"Found orphaned nodes: {orphaned}")
+
+        # Edges whose target belongs to a different graph_id are invalid here
+        invalid_edges = []
+        for u, v in graph.edge_list():
+            target_data = graph.get_node_data(v)
+
+            if target_data.get("graph_id") != self.current_graph_id:
+                invalid_edges.append((graph.get_node_data(u).get("id"), target_data.get("id")))
+
+        if invalid_edges:
+            print(f"Found invalid edges: {invalid_edges}")
+            edges_to_remove = []
+
+            for u, v in graph.edge_list():
+                if graph.get_node_data(v).get("graph_id") != self.current_graph_id:
+                    edges_to_remove.append((u, v))
+
+            for u, v in edges_to_remove:
+                try:
+                    graph.remove_edge(u, v)
+                except Exception as e:
+                    print(f"Error removing invalid edge from {u} to {v}: {e}")
+
+        print("Graph integrity verified successfully")
+
+        return True
+
+    def verify_graph_consistency(self):
+        """Verify consistency of the in-memory graph."""
+        graph_data = self._get_current_graph_data()
+        graph = graph_data["graph"]
+        node_map = graph_data["node_map"]
+
+        for n_id, idx in node_map.items():
+            node_data = graph.get_node_data(idx)
+
+            if node_data.get("id") is None or node_data.get("query") is None:
+                raise ValueError(f"Node '{n_id}' is missing required properties ('id' and 'query')")
+
+        for edge_data in graph.edges():
+            if edge_data.get("type") is None or edge_data.get("weight") is None:
+                raise ValueError("Found relationships with missing required properties ('type' and 'weight')")
+
+        print("Graph consistency verified successfully")
+
+        return True
+
+    async def close(self):
+        """Properly clean up all resources."""
+        try:
+            if hasattr(self, 'executor'):
+                self.executor.shutdown(wait=True)
+
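+            # asyncio.shield protects the crawler cleanup from being
+            # cancelled if close() itself is cancelled mid-await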
+            if hasattr(self, 'crawler'):
+                await asyncio.shield(self.crawler.cleanup_expired_sessions())
+                await asyncio.shield(self.crawler.cleanup_browser_context(getattr(self, "session_id", None)))
+        except Exception as e:
+            print(f"Error during cleanup: {e}")
+
+    @staticmethod
+    def cosine_similarity(v1: List[float], v2: List[float]) -> float:
+        """Calculate cosine similarity between two vectors."""
+        try:
+            v1_array = np.array(v1)
+            v2_array = np.array(v2)
+            denom = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
+
+            # A zero-length vector would yield nan (numpy does not raise), so guard explicitly
+            if denom == 0.0:
+                return 0.0
+            return float(np.dot(v1_array, v2_array) / denom)
+        except Exception as e:
+            print(f"Error calculating cosine similarity: {str(e)}")
+            return 0.0
+
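+# Manual smoke test: run this module directly to exercise graph construction,
+# retrieval, reasoning, and evaluation end to end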
+if __name__ == "__main__":
+    import os
+    from dotenv import load_dotenv
+    from src.reasoning.reasoner import Reasoner
+    from src.evaluation.evaluator import Evaluator
+
+    load_dotenv()
+
+    graph_search = GraphRAG(num_workers=24)
+    evaluator = Evaluator()
+    reasoner = Reasoner()
+
+    async def test_graph_search():
+        # Sample data for testing (not consumed by the interactive loop below)
+        queries = [
+ """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
1333
+ Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
1334
+ 1. Carbon pricing mechanisms
1335
+ 2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
1336
+ 3. Cross-border climate finance initiatives
1337
+
1338
+ Highlight the unique challenges faced by both advanced and emerging economies in addressing:
1339
+ 1. Energy poverty
1340
+ 2. Supply chain disruptions
1341
+ 3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
1342
+
1343
+ Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
1344
+ Note any significant policy gaps, regional collaborations, or innovative best practices.
1345
+ Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
1346
+ 1. Technological breakthroughs
1347
+ 2. Global market trends
1348
+ 3. Potential climate-related disasters
1349
+
1350
+ Present your analysis as a detailed, well-formatted report.""",
1351
+ """Analyse the impact of 'hot-money' on the value of Indian Rupee and answer the following questions:-
1352
+ 1. How does it affect the exchange rate?
1353
+ 2. How can it be mitigated/eliminated?
1354
+ 3. Why is it a problem?
1355
+ 4. What are the consequences?
1356
+ 5. What are the alternatives?
1357
+ - Evaluate the alternatives for pros and cons.
1358
+ - Evaluate the impact of alternatives on the exchange rate.
1359
+ - How can they be implemented?
1360
+ - What are the consequences of each alternative?
1361
+ - Evaluate the feasibility of the alternatives.
1362
+ - Pick top 5 alternatives and justify your choices in detail.
1363
+ 6. What are the implications for the Indian economy? Furthermore:-
1364
+ - Evaluate the impact of the chosen alternatives on the Indian economy.""",
1365
+ """Inflation has been an intrinsic past of human civilization since the very beginning. Answer the following questions:-
1366
+ 1. How true is the above statement?
1367
+ 2. What are the causes of inflation?
1368
+ 3. What are the consequences of inflation?
1369
+ 4. Can we completely eliminate inflation?""",
1370
+ """Perform a detailed comparison between the ancient Greece and Roman civilizations.
1371
+ 1. What were the key differences between the two civilizations?
1372
+ - Evaluate the differences in governance, society, and culture
1373
+ - Evaluate the differences in economy, trade, and military
1374
+ - Evaluate the differences in technology and infrastructure
1375
+ 2. What were the similarities between the two civilizations?
1376
+ - Evaluate the similarities in governance, society, and culture
1377
+ - Evaluate the similarities in economy, trade, and military
1378
+ - Evaluate the similarities in technology and infrastructure
1379
+ 3. How did these two civilizations influence each other?
1380
+ - Evaluate the influence of one civilization on the other
1381
+ 4. How did these two civilizations influence the modern world?
1382
+ 5. Was there another civilization that influenced these two? If yes, how?""",
1383
+ """Evaluate the long-term effects of colonialism on economic development in Asia:-
1384
+ 1. Include case studies of at least five different countries
1385
+ 2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
1386
+ - Evaluate the impact of colonialism on the economy of the country
1387
+ - Evaluate the impact of colonialism on the economy of the region
1388
+ - Evaluate the impact of colonialism on the economy of the world
1389
+ 3. How do these effects compare to Africa?"""
1390
+ ]
1391
+        follow_on_queries = [
+            "How is 'hot-money' related to the current economic situation in India?",
+            "What is inflation?",
+            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
+            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
+        ]
+
+        while True:
+            print("\n\nEnter query (finish input with an empty line):")
+            query_lines = []
+
+            while True:
+                line = input()
+
+                if line.strip() == "":
+                    break
+                query_lines.append(line)
+
+            query = "\n".join(query_lines).strip()
+
+            if query.lower() == "exit":
+                break
+ print("\n\n" + "="*15 + " Processing Query " + "="*15 + "\n\n")
1414
+
1415
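+            # Pipeline: build/update the graph for this query, retrieve an
+            # answer from it, then stream the reasoner's response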
+            await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)
+
+            answer = graph_search.query_graph(query)
+
+            response = ""
+            async for chunk in reasoner.reason(query, answer):
+                response += chunk
+                print(chunk, end="", flush=True)
+
+            graph_search.display_graph()  # returns the rendered HTML; unused in this CLI test
+
+            evaluation = await evaluator.evaluate_response(query, response, [answer])
+            print(f"Faithfulness: {evaluation['faithfulness']}")
+            print(f"Answer Relevancy: {evaluation['answer relevancy']}")
+            print(f"Contextual Recall: {evaluation['contextual recall']}")
+
+        await graph_search.close()
+
+    asyncio.run(test_graph_search())