|
|
|
|
|
|
|
import numpy as np
|
|
from mcts import MCTS
|
|
from ranking import train_ranking_model
|
|
from bs4 import BeautifulSoup
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from collections import deque, OrderedDict
|
|
import random
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
from sentence_transformers import SentenceTransformer
|
|
import hashlib
|
|
from twisted.internet import defer
|
|
import logging
|
|
import json
|
|
import os
|
|
from urllib.parse import urlparse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class SumTree:
    """Binary sum tree stored in a flat array.

    Every internal node holds the sum of its two children; the last
    ``capacity`` slots of ``tree`` are leaves holding the priorities of the
    items kept in ``data``. Items are written round-robin, so once the tree is
    full the oldest entry is overwritten.
    """

    def __init__(self, capacity):
        """Create an empty tree able to hold ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("SumTree capacity must be at least 1")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.write = 0      # next data slot to overwrite
        self.n_entries = 0  # number of filled slots, capped at capacity

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of tree node ``idx``."""
        # Iterative on purpose: when idx is already the root (capacity == 1)
        # there is nothing to update, whereas a recursive parent walk would
        # wrap around to index -1 and never terminate.
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``."""
        size = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= size:
                # No children: idx is a leaf.
                return idx
            left_sum = self.tree[left]
            if s <= left_sum:
                idx = left
            else:
                # Skip the mass of the left subtree and continue on the right.
                s -= left_sum
                idx = left + 1

    def total(self):
        """Return the sum of all stored priorities (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot when full."""
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write = (self.write + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` to ``p`` and refresh the ancestor sums."""
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(tree_index, priority, data)`` for cumulative priority ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[data_idx])
|
|
|
|
class PrioritizedReplayMemory:
    """Prioritized experience replay buffer backed by a ``SumTree``.

    A sample's priority is ``(|error| + epsilon) ** alpha``; ``alpha`` controls
    how strongly large errors are favoured and ``epsilon`` keeps every stored
    sample's priority strictly positive.
    """

    def __init__(self, capacity, alpha=0.6):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.epsilon = 1e-6

    def __len__(self):
        """Return the number of samples currently stored."""
        return self.tree.n_entries

    def _priority(self, error):
        """Convert an error magnitude into a sampling priority."""
        return (np.abs(error) + self.epsilon) ** self.alpha

    def add(self, error, sample):
        """Store ``sample`` with a priority derived from ``error``."""
        self.tree.add(self._priority(error), sample)

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` samples by stratified proportional sampling.

        The total priority mass is cut into ``batch_size`` equal segments and
        one point is drawn uniformly from each segment.

        Returns:
            ``(batch, idxs, weights)`` where ``idxs`` are tree indices suitable
            for ``update`` and ``weights`` is a NumPy array of
            importance-sampling weights normalised so the largest equals 1.

        Raises:
            ValueError: if ``batch_size`` is not positive or the memory is empty.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        total = self.tree.total()
        if self.tree.n_entries == 0 or total <= 0:
            # Without this guard the division below yields NaN weights and the
            # batch is filled with the tree's placeholder entries.
            raise ValueError("Cannot sample from an empty replay memory")

        batch = []
        idxs = []
        priorities = []
        segment = total / batch_size

        for i in range(batch_size):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, p, data = self.tree.get(s)
            batch.append(data)
            idxs.append(idx)
            priorities.append(p)

        # Explicit array: do not rely on NumPy scalar coercion of a Python list.
        probs = np.asarray(priorities, dtype=np.float64) / total
        weights = (self.tree.n_entries * probs) ** (-beta)
        weights /= weights.max()
        return batch, idxs, weights

    def update(self, idx, error):
        """Re-prioritise the sample stored at tree index ``idx``."""
        self.tree.update(idx, self._priority(error))
|
|
|
|
|
|
|
|
|
|
|
|
class ManagerModel(nn.Module):
    """Recurrent network used by the Manager to score the available options.

    An LSTM encodes the input sequence, the last time step is layer-normalised
    and a linear head maps it to one score per option.
    """

    def __init__(self, input_size, hidden_size, num_options):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_options)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x, hidden=None):
        """Return ``(option_scores, hidden)`` for a 2-D or 3-D input tensor."""
        # A 2-D input gets a length-1 sequence axis inserted at dim 1.
        sequence = x.unsqueeze(1) if x.dim() == 2 else x
        lstm_out, next_hidden = self.lstm(sequence, hidden)
        final_step = self.layer_norm(lstm_out[:, -1, :])
        return self.fc(final_step), next_hidden
|
|
|
|
class WorkerModel(nn.Module):
    """Recurrent network used by the Worker to score primitive actions.

    Same layout as the manager network: LSTM, layer norm on the last time
    step, then a linear head with one output per action.
    """

    def __init__(self, input_size, hidden_size, action_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, action_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.action_size = action_size

    def forward(self, x, hidden=None):
        """Return ``(action_scores, hidden)`` for a 2-D or 3-D input tensor."""
        sequence = x.unsqueeze(1) if x.dim() == 2 else x
        lstm_out, next_hidden = self.lstm(sequence, hidden)
        final_step = self.layer_norm(lstm_out[:, -1, :])
        return self.fc(final_step), next_hidden

    def act(self, state, epsilon=0.1):
        """Pick an action index with an epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest score for ``state``.
        """
        explore = random.random() < epsilon
        if explore:
            return random.randint(0, self.action_size - 1)

        # Run the forward pass on whichever device holds the parameters.
        device = next(self.parameters()).device
        batch = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            scores, _ = self(batch)
        return torch.argmax(scores, dim=1).item()
|
|
|
|
|
|
|
|
|
|
|
|
class RAGSummarizer:
    """Retrieval-augmented summarizer built on GPT-2 and a sentence encoder.

    Text is split into word chunks, the chunks are embedded, the chunks most
    similar to the query are selected, and GPT-2 is prompted with them.
    Generated summaries are kept in an LRU cache that is mirrored to a JSON
    file on disk.
    """

    def __init__(self, model_name='gpt2', embedding_model='all-MiniLM-L6-v2',
                 max_length=150, cache_capacity=100, persistent_cache_path='rag_cache.json'):
        """Load the models and restore the on-disk cache.

        Args:
            model_name: GPT-2 checkpoint used for generation.
            embedding_model: sentence-transformers model used for retrieval.
            max_length: maximum number of tokens generated per summary.
            cache_capacity: maximum number of cached summaries.
            persistent_cache_path: JSON file mirroring the cache.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)

        self.embedding_model = SentenceTransformer(embedding_model, device=self.device)
        self.max_length = max_length
        self.cache = LRUCache(cache_capacity)
        self.persistent_cache_path = persistent_cache_path
        self.load_persistent_cache()

    def load_persistent_cache(self):
        """Populate the in-memory cache from ``persistent_cache_path`` if it exists."""
        if not os.path.exists(self.persistent_cache_path):
            logger.info("No persistent cache found. Starting with empty cache.")
            return

        with open(self.persistent_cache_path, 'r', encoding='utf-8') as f:
            try:
                persistent_data = json.load(f)
            except json.JSONDecodeError:
                logger.warning("Persistent cache file is corrupted. Initializing empty cache.")
                return

        if not isinstance(persistent_data, dict):
            # Valid JSON but not the key -> summary mapping written by
            # save_persistent_cache; treat it like a corrupted file.
            logger.warning("Persistent cache file is corrupted. Initializing empty cache.")
            return

        for key, value in persistent_data.items():
            self.cache.put(key, value)
        logger.info(f"Loaded persistent cache with {len(persistent_data)} entries.")

    def save_persistent_cache(self):
        """Write the whole in-memory cache to ``persistent_cache_path``."""
        with open(self.persistent_cache_path, 'w', encoding='utf-8') as f:
            json.dump(self.cache.cache, f, indent=2)
        logger.info(f"Saved persistent cache with {len(self.cache.cache)} entries.")

    def save_rag_data(self, query, chunks, embeddings):
        """Dump the query, its chunks and their embeddings under ``rag_data/``."""
        data = {
            "query": query,
            "chunks": chunks,
            "embeddings": embeddings.tolist()
        }

        os.makedirs("rag_data", exist_ok=True)

        # The built-in hash() is salted per process for strings and may be
        # negative, so it cannot give a stable file name; use a digest instead.
        query_digest = hashlib.md5(query.encode('utf-8')).hexdigest()
        filename = f"rag_data/{query_digest}.json"
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        logger.info(f"Saved RAG data to {filename}")

    def split_into_chunks(self, text, chunk_size=200):
        """Split ``text`` into consecutive chunks of at most ``chunk_size`` words."""
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_relevant_chunks(self, query, chunks, embeddings, top_k=3):
        """Return up to ``top_k`` chunks ordered by decreasing cosine similarity to ``query``."""
        if embeddings.size(0) == 0:
            logger.warning("Embeddings are empty. Cannot retrieve relevant chunks.")
            return []
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=True)
        cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), embeddings.cpu().numpy())[0]
        top_indices = cosine_scores.argsort()[-top_k:][::-1]

        # Guard against embeddings that outnumber the chunk list.
        return [chunks[i] for i in top_indices if i < len(chunks)]

    def get_embeddings(self, chunks):
        """Embed ``chunks`` in batches of 32 and return a single tensor (empty if no chunks)."""
        batch_size = 32
        embeddings = [
            self.embedding_model.encode(chunks[i:i + batch_size], convert_to_tensor=True)
            for i in range(0, len(chunks), batch_size)
        ]
        if embeddings:
            return torch.cat(embeddings, dim=0)
        return torch.tensor([])

    def _fit_prompt_to_context(self, input_ids):
        """Trim prompt token ids so prompt plus generation fit the model's window."""
        # Assumes the model config exposes its window as n_positions, as
        # GPT-2 configs do; falls back to 1024 otherwise.
        window = getattr(self.model.config, "n_positions", 1024)
        budget = max(1, window - self.max_length)
        if input_ids.shape[1] > budget:
            logger.warning(
                f"Prompt has {input_ids.shape[1]} tokens; truncating to the last {budget}."
            )
            # Keep the tail so the trailing "Summary:" cue survives.
            input_ids = input_ids[:, -budget:]
        return input_ids

    def generate_summary(self, query, relevant_chunks):
        """Generate (or fetch from cache) a summary of ``relevant_chunks`` for ``query``.

        Returns the text after the last "Summary:" marker of the decoded
        output, or the string "Summary generation failed." if generation raises.
        """
        cache_key = hashlib.md5((query + ''.join(relevant_chunks)).encode()).hexdigest()
        cached_summary = self.cache.get(cache_key)
        if cached_summary:
            return cached_summary

        context = " ".join(relevant_chunks)
        prompt = f"Summarize the following content in relation to '{query}': {context}\n\nSummary:"

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        input_ids = self._fit_prompt_to_context(input_ids)

        try:
            output = self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.max_length,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                # top_k / top_p / temperature only take effect when sampling.
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                # GPT-2 has no pad token; reuse EOS so generate() has one.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        except Exception as e:
            logger.error(f"Error during summary generation: {str(e)}")
            return "Summary generation failed."

        self.save_rag_data(query, relevant_chunks, self.get_embeddings(relevant_chunks))

        summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
        summary = summary.split("Summary:")[-1].strip()

        self.cache.put(cache_key, summary)
        self.save_persistent_cache()

        return summary
|
|
|
|
|
|
|
|
|
|
|
|
class WorldModel(nn.Module):
    """Stacked-LSTM network with two heads: per-action scores and a scalar state value."""

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)
        self.value_head = nn.Linear(hidden_size, 1)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x, hidden=None):
        """Return ``(action_scores, state_value, hidden)`` for a 2-D or 3-D input tensor."""
        # A 2-D input gets a length-1 sequence axis inserted at dim 1.
        sequence = x.unsqueeze(1) if x.dim() == 2 else x
        lstm_out, next_hidden = self.lstm(sequence, hidden)
        features = self.layer_norm(lstm_out[:, -1, :])
        return self.fc(features), self.value_head(features), next_hidden
|
|
|
|
|
|
|
|
|
|
|
|
class Manager:
    """High-level DQN-style agent that chooses which option to execute.

    Holds an online ``ManagerModel``, a target copy used for bootstrapped
    targets, and a prioritized replay memory of
    ``(state, option, reward, next_state, done)`` transitions.
    """

    def __init__(self, state_size, num_options, hidden_size=128, learning_rate=0.001, gamma=0.99,
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, memory_capacity=1000,
                 device=torch.device("cpu")):
        """Build the networks, optimizer, scheduler and replay memory.

        Args:
            state_size: length of the state feature vector.
            num_options: number of options the manager can pick from.
            hidden_size: LSTM hidden size of the manager network.
            learning_rate: AdamW learning rate.
            gamma: discount factor.
            epsilon: initial exploration rate.
            epsilon_decay: multiplicative decay applied after each replay.
            epsilon_min: decay stops once epsilon is no longer above this value.
            memory_capacity: capacity of the replay memory.
            device: torch device the networks and batches live on.
        """
        self.state_size = state_size
        self.num_options = num_options
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.device = device

        self.model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
        self.target_model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.loss_fn = nn.MSELoss()
        # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=5, factor=0.5
        )

        self.memory = PrioritizedReplayMemory(capacity=memory_capacity, alpha=0.6)

        self.update_target_model()

    def _to_tensor(self, values, dtype=np.float32):
        """Convert a sequence of numbers/arrays into a tensor on ``self.device``."""
        # Going through one contiguous NumPy array avoids the slow
        # element-by-element path taken for a tuple of arrays.
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=self.device)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, option, reward, next_state, done, td_error):
        """Store a transition, prioritised by ``td_error``."""
        sample = (state, option, reward, next_state, done)
        self.memory.add(td_error, sample)

    def act(self, state):
        """Choose an option index for ``state`` with an epsilon-greedy rule."""
        if random.random() < self.epsilon:
            return random.randint(0, self.num_options - 1)

        # nn.LSTM has no `.weight` attribute, so the device comes from self.device.
        state_tensor = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            option_scores, _ = self.model(state_tensor)
        return torch.argmax(option_scores).item()

    def replay(self, batch_size, beta=0.4):
        """Run one prioritized-replay training step.

        Does nothing until at least ``batch_size`` transitions are stored.
        Afterwards the sampled priorities are refreshed with the new absolute
        TD errors and epsilon is decayed.
        """
        if self.memory.tree.n_entries < batch_size:
            return

        batch, idxs, weights = self.memory.sample(batch_size, beta)
        states, options, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states)
        next_states = self._to_tensor(next_states)
        options = self._to_tensor(options, dtype=np.int64).unsqueeze(1)
        rewards = self._to_tensor(rewards).unsqueeze(1)
        dones = self._to_tensor(dones).unsqueeze(1)
        weights = self._to_tensor(weights).unsqueeze(1)

        # Q-values of the options that were actually taken.
        current_q_values, _ = self.model(states)
        current_q_values = current_q_values.gather(1, options)

        # Bootstrapped targets from the frozen target network.
        with torch.no_grad():
            next_q_values, _ = self.target_model(next_states)
            max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
            target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))

        td_errors = target_q_values - current_q_values

        # Importance-sampling-weighted squared error.
        loss = (td_errors.pow(2) * weights).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step(loss.item())

        # reshape(-1) keeps the array 1-D even for batch_size == 1, where
        # squeeze() would produce a 0-d array that cannot be iterated.
        abs_errors = td_errors.detach().abs().cpu().numpy().reshape(-1)
        for idx, abs_error in zip(idxs, abs_errors):
            self.memory.update(idx, float(abs_error))

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
|
|
|
|
|
|
|
|
|
|
|
|
def truncate_text(text, max_length=1024):
    """Limit ``text`` to its first ``max_length`` whitespace-separated words.

    Text within the limit is returned unchanged; longer text is rebuilt from
    its first ``max_length`` words joined by single spaces.
    """
    words = text.split()
    return text if len(words) <= max_length else ' '.join(words[:max_length])
|
|
|
|
class AutonomousWebAgent:
    """Hierarchical web-research agent.

    A ``Manager`` picks one of three options per search call (0: web search,
    1: summarize, 2: RAG-based generation). A worker Q-network is trained on
    the same transitions, and a JSON-backed knowledge base stores retrieved
    documents together with their sentence embeddings.
    """

    def __init__(self, state_size, action_size, num_options, hidden_size=64, learning_rate=0.001,
                 gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
                 knowledge_base_path='knowledge_base.json'):
        """Build the summarizer, networks, replay memories and knowledge base.

        Args:
            state_size: length of the state vector fed to the networks.
                ``extract_features`` produces 7 features, so this presumably
                has to be 7 -- confirm against callers.
            action_size: number of worker actions.
            num_options: number of manager options.
            hidden_size: hidden size of the worker network.
            learning_rate: learning rate for both optimizers.
            gamma: discount factor.
            epsilon: initial exploration rate (worker and manager).
            epsilon_decay: multiplicative epsilon decay per replay.
            epsilon_min: lower threshold at which decay stops.
            knowledge_base_path: JSON file holding the knowledge base.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.num_options = num_options
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.summarizer = RAGSummarizer()
        self.device = self.summarizer.device

        # RAGSummarizer() already loads 'all-MiniLM-L6-v2' on this device;
        # share that encoder instead of loading a second copy.
        self.embedding_model = self.summarizer.embedding_model

        self.worker_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
        self.worker_target_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
        self.worker_optimizer = optim.AdamW(self.worker_model.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.worker_loss_fn = nn.MSELoss()
        # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
        self.worker_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.worker_optimizer, 'min', patience=5, factor=0.5
        )
        self.worker_memory = PrioritizedReplayMemory(capacity=2000, alpha=0.6)
        self.update_worker_target_model()

        self.manager = Manager(state_size, num_options, hidden_size=128, learning_rate=learning_rate,
                               gamma=gamma, epsilon=epsilon, epsilon_decay=epsilon_decay,
                               epsilon_min=epsilon_min, memory_capacity=1000, device=self.device)

        self.mcts = MCTS(initial_state="")
        logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")

        # Accumulated reward keyed by (result host, query).
        self.site_performance = {}

        # Search-URL prefixes; the query is appended to build the final URL.
        self.all_search_sites = [
            "https://en.wikibooks.org/w/index.php?search=",
            "https://en.wikiversity.org/w/index.php?search=",
            "https://commons.wikimedia.org/w/index.php?search=",
            "https://stackexchange.com/search?q=",
            "https://arxiv.org/search/?query=",
            "https://www.ncbi.nlm.nih.gov/pmc/?term=",
            "https://www.gutenberg.org/ebooks/search/?query=",
            "https://openlibrary.org/search?q=",
            "https://doaj.org/search/articles?ref=homepage&q=",
            "https://www.ted.com/search?q=",
            "https://en.citizendium.org/wiki?search=",
            "https://www.jstor.org/action/doBasicSearch?Query=",
            "https://archive.org/search.php?query=",
            "https://search.scielo.org/?q=",
            "https://paperswithcode.com/search?q=",
            "https://www.reddit.com/search/?q=",
            "https://huggingface.co./models?search=",
            "https://huggingface.co./datasets?search=",
            "https://machinelearningmastery.com/?s=",
            "https://www.kaggle.com/search?q=",
            "https://towardsdatascience.com/search?q=",
            "https://github.com/search?q=",
            "https://stackoverflow.com/search?q=",
            "https://www.youtube.com/results?search_query=",
            "https://www.slideshare.net/search/slideshow?searchfrom=header&q="
        ]

        self.knowledge_base_path = knowledge_base_path
        self.knowledge_base = []
        self.kb_embeddings = None
        self.load_knowledge_base()

        self.additional_features = ['image_count', 'script_count', 'css_count']

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, filename):
        """Save model weights, optimizer states and exploration rates to ``filename``."""
        state = {
            'worker_model': self.worker_model.state_dict(),
            'manager_model': self.manager.model.state_dict(),
            'worker_optimizer': self.worker_optimizer.state_dict(),
            'manager_optimizer': self.manager.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'manager_epsilon': self.manager.epsilon
        }
        torch.save(state, filename)
        logger.info(f"Saved agent state to {filename}")

    def load(self, filename):
        """Restore a checkpoint written by ``save`` and resync the target networks."""
        state = torch.load(filename, map_location=self.device)
        self.worker_model.load_state_dict(state['worker_model'])
        self.manager.model.load_state_dict(state['manager_model'])
        self.worker_optimizer.load_state_dict(state['worker_optimizer'])
        self.manager.optimizer.load_state_dict(state['manager_optimizer'])
        self.epsilon = state['epsilon']
        # Checkpoints written before this key existed keep the current value.
        self.manager.epsilon = state.get('manager_epsilon', self.manager.epsilon)
        # Target networks must mirror the freshly loaded weights.
        self.update_worker_target_model()
        self.manager.update_target_model()
        logger.info(f"Loaded agent state from {filename}")

    # ------------------------------------------------------------------
    # Text generation
    # ------------------------------------------------------------------

    def generate_text(self, prompt):
        """Summarize ``prompt`` against itself through the RAG summarizer."""
        chunks = self.summarizer.split_into_chunks(prompt)
        embeddings = self.summarizer.get_embeddings(chunks)
        relevant_chunks = self.summarizer.retrieve_relevant_chunks(query=prompt, chunks=chunks, embeddings=embeddings)
        return self.summarizer.generate_summary(prompt, relevant_chunks)

    # ------------------------------------------------------------------
    # Knowledge base
    # ------------------------------------------------------------------

    def load_knowledge_base(self):
        """Load the knowledge base from disk and embed every document's content."""
        if not os.path.exists(self.knowledge_base_path):
            logger.warning(f"Knowledge base file {self.knowledge_base_path} does not exist. Initializing empty KB.")
            self.knowledge_base = []
            self.kb_embeddings = torch.tensor([]).to(self.device)
            return

        try:
            with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
                self.knowledge_base = json.load(f)
        except json.JSONDecodeError:
            # An unreadable KB file should not prevent the agent from starting.
            logger.warning(f"Knowledge base file {self.knowledge_base_path} is corrupted. Initializing empty KB.")
            self.knowledge_base = []

        if self.knowledge_base:
            texts = [doc['content'] for doc in self.knowledge_base]
            self.kb_embeddings = self.embedding_model.encode(texts, convert_to_tensor=True)
            logger.info(f"Loaded {len(self.knowledge_base)} documents into the knowledge base.")
        else:
            self.kb_embeddings = torch.tensor([]).to(self.device)
            logger.info("Knowledge base is empty.")

    def save_knowledge_base(self):
        """Write the knowledge base to ``knowledge_base_path`` as JSON."""
        with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
            json.dump(self.knowledge_base, f, indent=2)
        logger.info(f"Knowledge base saved with {len(self.knowledge_base)} documents.")

    def add_document_to_kb(self, title, content, metadata=None):
        """Append a document, extend the embedding matrix and persist the KB."""
        document = {
            "title": title,
            "content": content,
            "metadata": metadata or {}
        }
        self.knowledge_base.append(document)

        new_embedding = self.embedding_model.encode([content], convert_to_tensor=True).to(self.device)
        if self.kb_embeddings is None or self.kb_embeddings.numel() == 0:
            self.kb_embeddings = new_embedding
        else:
            self.kb_embeddings = torch.cat([self.kb_embeddings, new_embedding], dim=0)

        self.save_knowledge_base()
        logger.info(f"Added new document to knowledge base: {title}")

    def retrieve_from_kb(self, query, top_k=5):
        """Return up to ``top_k`` KB documents most similar to ``query``.

        Each returned document is a shallow copy of the stored one with an
        extra ``'score'`` key holding the cosine similarity as a Python float.
        """
        if not self.knowledge_base:
            logger.warning("Knowledge base is empty. No documents to retrieve.")
            return []

        query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)

        if self.kb_embeddings is None or self.kb_embeddings.numel() == 0:
            logger.warning("Knowledge base embeddings are empty. No documents to retrieve.")
            return []

        if query_embedding.size(1) != self.kb_embeddings.size(1):
            logger.error("Dimension mismatch between query embedding and KB embeddings.")
            return []

        cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), self.kb_embeddings.cpu().numpy())[0]
        top_indices = cosine_scores.argsort()[-top_k:][::-1]

        retrieved_docs = []
        for idx in top_indices:
            if idx >= len(self.knowledge_base):
                continue
            # Copy before annotating: writing a NumPy float into the stored
            # document would make the next save_knowledge_base() fail in
            # json.dump and would persist a stale score.
            doc = dict(self.knowledge_base[idx])
            doc['score'] = float(cosine_scores[idx])
            retrieved_docs.append(doc)

        logger.info(f"Retrieved top {len(retrieved_docs)} documents from Knowledge Base for the query.")
        return retrieved_docs

    # ------------------------------------------------------------------
    # Web retrieval
    # ------------------------------------------------------------------

    @defer.inlineCallbacks
    def retrieve_from_web(self, query, top_k=5):
        """Refine ``query`` with MCTS, run a web search and return up to ``top_k`` results.

        The function yields inside its body, so it has to be driven by
        ``inlineCallbacks``; without the decorator callers receive a bare
        generator object instead of the results. Fires with a list of results
        (empty on failure).
        """
        logger.info(f"Performing web search for query: {query}")
        mcts_iterations = self.calculate_mcts_iterations(np.zeros(self.state_size, dtype=np.float32))
        self.mcts = MCTS(initial_state=query, num_simulations=mcts_iterations)

        try:
            new_query = yield self.mcts.run()
            logger.debug(f"New query from MCTS: {new_query}")

            search_sites = self.select_search_sites(new_query)
            results = yield self.mcts.web_search(new_query, search_sites)
            logger.debug(f"Web search completed. Found {len(results)} results")
            return results[:top_k] if results else []
        except Exception as e:
            logger.error(f"Error during MCTS or web search: {str(e)}", exc_info=True)
            return []

    def combine_documents(self, kb_docs, web_docs):
        """Concatenate KB and web documents (KB first)."""
        combined = kb_docs + web_docs
        logger.info(f"Combined {len(kb_docs)} KB documents and {len(web_docs)} Web documents.")
        return combined

    def save_llm_training_data(self, query, content, summary=None, link=None, title=None):
        """Append one training record as a JSON line under ``llm_training_data/``."""
        data = {
            "query": query,
            "search_result": {
                "link": link,
                "title": title
            },
            "content": content,
            "description": summary
        }

        os.makedirs("llm_training_data", exist_ok=True)
        file_path = "llm_training_data/llm_training_data.jsonl"

        with open(file_path, 'a', encoding='utf-8') as f:
            json.dump(data, f)
            f.write('\n')

        logger.info(f"Appended LLM training data to {file_path}")

    # ------------------------------------------------------------------
    # Replay memories and action selection
    # ------------------------------------------------------------------

    def remember_manager(self, state, option, reward, next_state, done, td_error):
        """Store a manager transition, prioritised by ``td_error``."""
        self.manager.remember(state, option, reward, next_state, done, td_error)

    def remember_worker(self, state, action, reward, next_state, done):
        """Store a worker transition; the reward doubles as its initial priority error."""
        self.worker_memory.add(reward, (state, action, reward, next_state, done))

    def act_manager(self, state):
        """Ask the manager for an option index."""
        return self.manager.act(state)

    def act_worker(self, state):
        """Choose a worker action with an epsilon-greedy rule.

        The worker network is a ``WorldModel``, which has no ``act`` method
        and returns three values, so the policy is implemented here.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)
        state_tensor = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            action_scores, _, _ = self.worker_model(state_tensor)
        return torch.argmax(action_scores, dim=1).item()

    def _to_tensor(self, values, dtype=np.float32):
        """Convert a sequence of numbers/arrays into a tensor on ``self.device``."""
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=self.device)

    def replay_manager(self, batch_size=32, beta=0.4):
        """Run one manager training step."""
        self.manager.replay(batch_size, beta)

    def replay_worker(self, batch_size=32, beta=0.4):
        """Run one prioritized-replay training step for the worker network.

        Does nothing until at least ``batch_size`` transitions are stored.
        """
        if self.worker_memory.tree.n_entries < batch_size:
            return

        batch, idxs, weights = self.worker_memory.sample(batch_size, beta)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states)
        next_states = self._to_tensor(next_states)
        actions = self._to_tensor(actions, dtype=np.int64).unsqueeze(1)
        rewards = self._to_tensor(rewards).unsqueeze(1)
        dones = self._to_tensor(dones).unsqueeze(1)
        weights = self._to_tensor(weights).unsqueeze(1)

        # WorldModel.forward returns (action_scores, state_value, hidden).
        current_q_values, _, _ = self.worker_model(states)
        current_q_values = current_q_values.gather(1, actions)

        with torch.no_grad():
            next_q_values, _, _ = self.worker_target_model(next_states)
            max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
            target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))

        td_errors = target_q_values - current_q_values

        # Importance-sampling-weighted squared error.
        loss = (td_errors.pow(2) * weights).mean()

        self.worker_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.worker_model.parameters(), max_norm=1.0)
        self.worker_optimizer.step()
        self.worker_scheduler.step(loss.item())

        # reshape(-1) stays 1-D even when batch_size == 1.
        abs_errors = td_errors.detach().abs().cpu().numpy().reshape(-1)
        for idx, abs_error in zip(idxs, abs_errors):
            self.worker_memory.update(idx, float(abs_error))

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
            logger.debug(f"Updated epsilon to: {self.epsilon}")

    # ------------------------------------------------------------------
    # Weight (de)serialisation helpers
    # ------------------------------------------------------------------

    def load_worker_model(self, name):
        """Load worker weights from ``name`` and resync the worker target network."""
        self.worker_model.load_state_dict(torch.load(name, map_location=self.device))
        # Mirrors load_manager_model, which also refreshes its target network.
        self.update_worker_target_model()
        logger.info(f"Loaded worker model weights from {name}")

    def save_worker_model(self, name):
        """Save worker weights to ``name``."""
        torch.save(self.worker_model.state_dict(), name)
        logger.info(f"Saved worker model weights to {name}")

    def load_manager_model(self, name):
        """Load manager weights from ``name`` and resync the manager target network."""
        self.manager.model.load_state_dict(torch.load(name, map_location=self.device))
        self.manager.update_target_model()
        logger.info(f"Loaded manager model weights from {name}")

    def save_manager_model(self, name):
        """Save manager weights to ``name``."""
        torch.save(self.manager.model.state_dict(), name)
        logger.info(f"Saved manager model weights to {name}")

    def update_worker_target_model(self):
        """Copy the worker network's weights into its target network."""
        self.worker_target_model.load_state_dict(self.worker_model.state_dict())
        logger.info("Updated worker target model with current model weights")

    def update_manager_target_model(self):
        """Copy the manager network's weights into its target network."""
        self.manager.update_target_model()
        logger.info("Updated manager target model with current model weights")

    # ------------------------------------------------------------------
    # Features and reward
    # ------------------------------------------------------------------

    def extract_features(self, content, query):
        """Build the 7-element feature vector describing ``content`` for ``query``.

        Order: word count, link count, header count, semantic similarity,
        image count, script count, stylesheet-link count.
        """
        content = truncate_text(content)
        query = truncate_text(query)
        soup = BeautifulSoup(content, 'html.parser')
        text = soup.get_text()
        word_count = len(text.split())
        link_count = len(soup.find_all('a'))
        header_count = len(soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']))

        text_embedding = self.embedding_model.encode([text], convert_to_tensor=True).to(self.device)
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)
        semantic_similarity = cosine_similarity(text_embedding.cpu().numpy(), query_embedding.cpu().numpy())[0][0]

        image_count = len(soup.find_all('img'))
        script_count = len(soup.find_all('script'))
        css_count = len(soup.find_all('link', rel='stylesheet'))

        return np.array([word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count])

    def calculate_reward(self, content, query):
        """Score ``content`` for ``query`` with the ranking model; 0 on any failure."""
        try:
            ranked_results = train_ranking_model(query, [{'content': content}])
            logger.debug(f"Ranked results: {ranked_results}")
            if ranked_results and isinstance(ranked_results[0], dict) and 'predicted_score' in ranked_results[0]:
                reward = max(ranked_results[0]['predicted_score'], 0)
                logger.debug(f"Calculated reward: {reward}")
                return reward
            logger.warning(f"Invalid ranked results: {ranked_results}")
            return 0
        except Exception as e:
            logger.error(f"Error in calculate_reward: {str(e)}", exc_info=True)
            return 0

    def select_search_sites(self, query, num_sites=5):
        """Return up to ``num_sites`` search URLs (prefix + query) for ``query``.

        Sites that earned reward for this exact query are preferred, best
        first; otherwise a random subset of ``all_search_sites`` is used.
        """
        site_scores = {}
        for (site, q), score in self.site_performance.items():
            if q == query:
                site_scores[site] = site_scores.get(site, 0) + score

        # site_performance is keyed by the result's host (netloc), which is not
        # a usable search URL; map hosts back to their search-URL prefixes.
        prefix_by_host = {urlparse(prefix).netloc: prefix for prefix in self.all_search_sites}
        ranked_sites = sorted(site_scores.items(), key=lambda item: item[1], reverse=True)

        top_sites = []
        for site, _ in ranked_sites:
            if site in prefix_by_host:
                top_sites.append(prefix_by_host[site])
            elif "://" in site:
                # Already a full URL prefix; use it as stored.
                top_sites.append(site)
            if len(top_sites) >= num_sites:
                break

        if not top_sites:
            # min() keeps random.sample from raising when num_sites exceeds
            # the number of configured sites.
            top_sites = random.sample(self.all_search_sites, min(num_sites, len(self.all_search_sites)))

        return [site + query for site in top_sites]

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    @defer.inlineCallbacks
    def search(self, query, max_steps=2):
        """Execute one manager-selected option for ``query`` and train both agents.

        Args:
            query: the user query.
            max_steps: accepted for interface compatibility; the body performs
                a single option per call and does not read this value.

        Fires with the reward collected, or -1 if an exception occurred.
        """
        logger.info(f"Starting search for query: {query}")
        state = np.zeros(self.state_size, dtype=np.float32)
        total_reward = 0
        content = ""
        done = False
        results = None

        try:
            option = self.act_manager(state)
            logger.debug(f"Manager selected option: {option}")

            if option == 0:
                logger.debug("Executing Search Option")
                results = yield self.retrieve_from_web(query)
                if results:
                    top_result = results[0]
                    content = top_result['content']
                    site = urlparse(top_result['link']).netloc
                    self.save_llm_training_data(
                        query,
                        content,
                        summary=top_result.get('summary'),
                        link=top_result.get('link'),
                        title=top_result.get('title')
                    )
                    self.add_document_to_kb(title=top_result.get('title', 'No Title'), content=content, metadata=top_result.get('meta', {}))
                    next_state = self.extract_features(content, query)
                    reward = self.calculate_reward(content, query)
                    logger.debug(f"Extracted features: {next_state}, Reward: {reward}")

                    key = (site, query)
                    self.site_performance[key] = self.site_performance.get(key, 0) + reward

                    self.remember_manager(state, option, reward, next_state, done, td_error=reward)
                    self.remember_worker(state, 0, reward, next_state, done)

                    state = next_state.astype(np.float32)
                    total_reward += reward
                else:
                    reward = -1
                    logger.warning(f"No results for query: {query}")
                    self.remember_manager(state, option, reward, state, True, td_error=reward)

            elif option == 1:
                logger.debug("Executing Summarize Option")
                if content:
                    # summarize() chunks and ranks the content before calling
                    # generate_summary(query, relevant_chunks); calling
                    # generate_summary directly here had the arguments swapped.
                    summary = self.summarize(content, query)
                    self.save_llm_training_data(
                        query,
                        content,
                        summary=summary,
                        link=results[0].get('link') if results else None,
                        title=results[0].get('title') if results else None
                    )
                    reward = self.calculate_reward(summary, query)
                    next_state = self.extract_features(summary, query)
                    logger.info(f"Summary:\n{summary}")
                    logger.info(f"Summarized content. Reward: {reward}")

                    self.remember_manager(state, option, reward, next_state, done, td_error=reward)
                    self.remember_worker(state, 1, reward, next_state, done)

                    state = next_state.astype(np.float32)
                    total_reward += reward
                else:
                    reward = -1
                    logger.warning("No content to summarize")
                    self.remember_manager(state, option, reward, state, True, td_error=reward)

            elif option == 2:
                logger.debug("Executing RAG-based Generation Option")
                kb_docs = self.retrieve_from_kb(query, top_k=5)
                web_docs = []
                combined_docs = self.combine_documents(kb_docs, web_docs)
                generated_output = self.generate_rag_response(query, combined_docs)
                logger.info(f"Generated Output:\n{generated_output}")
                self.save_llm_training_data(
                    query,
                    generated_output,
                    summary=None,
                    link=None,
                    title="RAG-generated response"
                )
                reward = self.calculate_reward(generated_output, query)
                next_state = self.extract_features(generated_output, query)

                self.remember_manager(state, option, reward, next_state, done, td_error=reward)
                self.remember_worker(state, 2, reward, next_state, done)

                state = next_state.astype(np.float32)
                total_reward += reward

            else:
                logger.warning(f"Unknown option selected by Manager: {option}")

            self.replay_manager(batch_size=32, beta=0.4)
            self.replay_worker(batch_size=32, beta=0.4)

            self.update_worker_target_model()
            self.update_manager_target_model()

            logger.info(f"Search completed. Total reward: {total_reward}")
            return total_reward
        except Exception as e:
            logger.error(f"Error during search: {str(e)}", exc_info=True)
            return -1

    def summarize(self, content, query):
        """Chunk ``content``, pick the chunks relevant to ``query`` and summarize them."""
        chunks = self.summarizer.split_into_chunks(content)
        embeddings = self.summarizer.get_embeddings(chunks)
        relevant_chunks = self.summarizer.retrieve_relevant_chunks(query, chunks, embeddings)
        summary = self.summarizer.generate_summary(query, relevant_chunks)

        # Persist the full chunk set (generate_summary stores only the relevant ones).
        self.summarizer.save_rag_data(query, chunks, embeddings)

        return summary

    def calculate_mcts_iterations(self, state):
        """Derive the MCTS simulation count (between 1 and 5) from the mean of ``state``."""
        base_iterations = 2
        complexity_factor = np.mean(state) / 100
        iterations = int(base_iterations * (1 + complexity_factor))
        max_iterations = 5
        # Never return fewer than one simulation, even for negative states.
        return max(1, min(iterations, max_iterations))

    def generate_rag_response(self, query, combined_docs):
        """Generate an answer to ``query`` conditioned on ``combined_docs``.

        Returns a fixed apology string when no documents are given and
        "RAG response generation failed." when generation raises.
        """
        if not combined_docs:
            logger.warning("No documents available for RAG-based generation.")
            return "I'm sorry, I couldn't find any relevant information."

        context = "\n\n".join([f"Title: {doc.get('title', 'No Title')}\nContent: {doc.get('content', '')}" for doc in combined_docs])
        prompt = f"Query: {query}\n\nContext:\n{context}\n\nAnswer:"

        cache_key = hashlib.md5(prompt.encode()).hexdigest()
        cached_response = self.summarizer.cache.get(cache_key)
        if cached_response:
            logger.debug("Using cached RAG response.")
            return cached_response

        input_ids = self.summarizer.tokenizer.encode(prompt, return_tensors='pt').to(self.summarizer.device)

        # Several full documents can overflow the model's context window.
        # Assumes the config exposes it as n_positions, as GPT-2 configs do;
        # falls back to 1024 otherwise. The tail is kept so the trailing
        # "Answer:" cue survives.
        window = getattr(self.summarizer.model.config, "n_positions", 1024)
        prompt_budget = max(1, window - self.summarizer.max_length)
        if input_ids.shape[1] > prompt_budget:
            logger.warning(f"RAG prompt has {input_ids.shape[1]} tokens; truncating to the last {prompt_budget}.")
            input_ids = input_ids[:, -prompt_budget:]

        try:
            output = self.summarizer.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.summarizer.max_length,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                # top_k / top_p / temperature only take effect when sampling.
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                # GPT-2 has no pad token; reuse EOS so generate() has one.
                pad_token_id=self.summarizer.tokenizer.eos_token_id,
            )
        except Exception as e:
            logger.error(f"Error during RAG response generation: {str(e)}")
            return "RAG response generation failed."

        response = self.summarizer.tokenizer.decode(output[0], skip_special_tokens=True)
        answer = response.split("Answer:")[-1].strip()

        self.summarizer.cache.put(cache_key, answer)
        self.summarizer.save_persistent_cache()
        return answer

    # ------------------------------------------------------------------
    # Option dispatch (placeholders)
    # ------------------------------------------------------------------

    def select_option(self, option):
        """Return the bound method implementing ``option``, or ``None`` if unknown."""
        option_actions = {
            0: self.perform_search,
            1: self.perform_summarization,
            2: self.perform_rag_generation
        }
        action = option_actions.get(option)
        if action is None:
            logger.error(f"No action defined for option: {option}")
        return action

    def perform_search(self, query):
        """Placeholder for the search action; currently does nothing."""
        pass

    def perform_summarization(self, content, query):
        """Placeholder for the summarization action; currently does nothing."""
        pass

    def perform_rag_generation(self, query, combined_docs):
        """Placeholder for the RAG-based generation action; currently does nothing."""
        pass
|
|
|
|
|
|
|
|
|
|
|
|
class LRUCache:
    """Least-recently-used cache on top of an ``OrderedDict``.

    The most recently used key sits at the end of ``cache``; once the number
    of entries exceeds ``capacity`` the entry at the front is evicted.
    """

    def __init__(self, capacity):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """Return the value stored for ``key`` (marking it most recent), or ``None``."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return None
        return self.cache[key]

    def put(self, key, value):
        """Insert or overwrite ``key`` as the most recent entry, evicting the oldest if over capacity."""
        self.cache[key] = value
        # A new key is already last; an existing key keeps its old slot on
        # assignment, so move it to the end explicitly.
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)
|
|
|
|
|
|
|