import re
import asyncio
import warnings
import logging
from collections.abc import Generator

import aiohttp
import requests
from bs4 import BeautifulSoup
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.document_compressors.embeddings_filter import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.retrievers import BM25Retriever
from transformers import AutoTokenizer, AutoModelForMaskedLM
import optimum.bettertransformer.transformation

try:
    from qdrant_client import QdrantClient, models
except ImportError:
    QdrantClient = None
    models = None

from .qdrant_retriever import MyQdrantSparseVectorRetriever
from .semantic_chunker import BoundedSemanticChunker


class LangchainCompressor:

    def __init__(self, device="cuda", num_results: int = 5, similarity_threshold: float = 0.5, chunk_size: int = 500,
                 ensemble_weighting: float = 0.5, splade_batch_size: int = 2, keyword_retriever: str = "bm25",
                 model_cache_dir: str = None, chunking_method: str = "character-based",
                 chunker_breakpoint_threshold_amount: int = 10):
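        """Configure hybrid web-page retrieval: dense sentence embeddings plus a keyword retriever.

        keyword_retriever must be "bm25" or "splade". The SPLADE option requires qdrant-client
        and loads the naver/efficient-splade-VI-BT-large document and query models.
        """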
        self.device = device
        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2", model_kwargs={"device": device},
                                                cache_folder=model_cache_dir)
        if keyword_retriever == "splade":
            if QdrantClient is None:
                raise ImportError("Package qdrant_client is missing. "
                                  "Please install it using 'pip install qdrant-client'")
            self.splade_doc_tokenizer = AutoTokenizer.from_pretrained("naver/efficient-splade-VI-BT-large-doc",
                                                                      cache_dir=model_cache_dir)
            self.splade_doc_model = AutoModelForMaskedLM.from_pretrained("naver/efficient-splade-VI-BT-large-doc",
                                                                         cache_dir=model_cache_dir).to(self.device)
            self.splade_query_tokenizer = AutoTokenizer.from_pretrained("naver/efficient-splade-VI-BT-large-query",
                                                                        cache_dir=model_cache_dir)
            self.splade_query_model = AutoModelForMaskedLM.from_pretrained("naver/efficient-splade-VI-BT-large-query",
                                                                           cache_dir=model_cache_dir).to(self.device)
            optimum_logger = optimum.bettertransformer.transformation.logger
            original_log_level = optimum_logger.level
            # Temporarily silence optimum's logger while converting the SPLADE models to BetterTransformer
            optimum_logger.setLevel(logging.ERROR)
            self.splade_doc_model.to_bettertransformer()
            self.splade_query_model.to_bettertransformer()
            optimum_logger.setLevel(original_log_level)
            self.splade_batch_size = splade_batch_size

        self.spaces_regex = re.compile(r" {3,}")
        self.num_results = num_results
        self.similarity_threshold = similarity_threshold
        self.chunking_method = chunking_method
        self.chunk_size = chunk_size
        self.chunker_breakpoint_threshold_amount = chunker_breakpoint_threshold_amount
        self.ensemble_weighting = ensemble_weighting
        self.keyword_retriever = keyword_retriever

    def preprocess_text(self, text: str) -> str:
        text = text.replace("\n", " \n")
        text = self.spaces_regex.sub(" ", text)
        text = text.strip()
        return text

    def retrieve_documents(self, query: str, url_list: list[str]) -> Generator[str, None, list[Document]]:
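        """Fetch the given URLs and retrieve the chunks most relevant to the query.

        This is a generator: it yields progress messages while working and returns the final
        list of documents as the generator's return value (StopIteration.value).
        """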
yield "Downloading webpages..." |
|
html_url_tupls = zip(asyncio.run(async_fetch_urls(url_list)), url_list) |
|
html_url_tupls = [(content, url) for content, url in html_url_tupls if content is not None] |
|
if not html_url_tupls: |
|
return [] |
|
|
|
        documents = [html_to_plaintext_doc(html, url) for html, url in html_url_tuples]
        if self.chunking_method == "semantic":
            text_splitter = BoundedSemanticChunker(self.embeddings, breakpoint_threshold_type="percentile",
                                                   breakpoint_threshold_amount=self.chunker_breakpoint_threshold_amount,
                                                   max_chunk_size=self.chunk_size)
        else:
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=10,
                                                           separators=["\n\n", "\n", ".", ", ", " ", ""])
        yield "Chunking page texts..."
        split_docs = text_splitter.split_documents(documents)
        yield "Retrieving relevant results..."

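        # Dense retrieval: embed the chunks and search them with an in-memory FAISS index.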
        faiss_retriever = FAISS.from_documents(split_docs, self.embeddings).as_retriever(
            search_kwargs={"k": self.num_results}
        )

        if self.keyword_retriever == "bm25":
            keyword_retriever = BM25Retriever.from_documents(split_docs, preprocess_func=self.preprocess_text)
            keyword_retriever.k = self.num_results
        elif self.keyword_retriever == "splade":
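            # Sparse keyword retrieval with the SPLADE document/query expansion models, stored in
            # an in-memory Qdrant collection that only holds sparse vectors.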
            client = QdrantClient(location=":memory:")
            collection_name = "sparse_collection"
            vector_name = "sparse_vector"

            client.create_collection(
                collection_name,
                vectors_config={},
                sparse_vectors_config={
                    vector_name: models.SparseVectorParams(
                        index=models.SparseIndexParams(
                            on_disk=False,
                        )
                    )
                },
            )

            keyword_retriever = MyQdrantSparseVectorRetriever(
                splade_doc_tokenizer=self.splade_doc_tokenizer,
                splade_doc_model=self.splade_doc_model,
                splade_query_tokenizer=self.splade_query_tokenizer,
                splade_query_model=self.splade_query_model,
                device=self.device,
                client=client,
                collection_name=collection_name,
                sparse_vector_name=vector_name,
                sparse_encoder=None,
                batch_size=self.splade_batch_size,
                k=self.num_results
            )
            keyword_retriever.add_documents(split_docs)
        else:
            raise ValueError("self.keyword_retriever must be one of ('bm25', 'splade')")

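        # Compress the dense results: drop near-duplicate chunks, then keep only chunks whose
        # embedding similarity to the query meets the configured threshold.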
        redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embeddings)
        embeddings_filter = EmbeddingsFilter(embeddings=self.embeddings, k=None,
                                             similarity_threshold=self.similarity_threshold)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[redundant_filter, embeddings_filter]
        )
        compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor,
                                                               base_retriever=faiss_retriever)

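        # Blend both retrievers: ensemble_weighting is the weight given to the compressed dense
        # results, and the remainder goes to the keyword retriever.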
        ensemble_retriever = EnsembleRetriever(
            retrievers=[compression_retriever, keyword_retriever],
            weights=[self.ensemble_weighting, 1 - self.ensemble_weighting]
        )
        compressed_docs = ensemble_retriever.invoke(query)

        return compressed_docs[:self.num_results]


async def async_download_html(url, headers):
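    """Fetch a single URL and return its body as text, or None if the download fails."""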
    async with aiohttp.ClientSession(headers=headers, timeout=aiohttp.ClientTimeout(total=10)) as session:
        try:
            resp = await session.get(url)
            return await resp.text()
        except UnicodeDecodeError:
            print(f"LLM_Web_search | {url} generated an exception: Expected content type text/html. "
                  f"Got {resp.headers.get('Content-Type')}.")
        except asyncio.TimeoutError:
            print('LLM_Web_search | %r did not load in time' % url)
        except Exception as exc:
            print('LLM_Web_search | %r generated an exception: %s' % (url, exc))
    return None


async def async_fetch_urls(urls):
    headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0",
               "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
               "Accept-Language": "en-US,en;q=0.5"}
    webpages = await asyncio.gather(*[async_download_html(url, headers) for url in urls])
    return webpages


def docs_to_pretty_str(docs) -> str:
    ret_str = ""
    for i, doc in enumerate(docs):
        ret_str += f"Result {i+1}:\n"
        ret_str += f"{doc.page_content}\n"
        ret_str += f"Source URL: {doc.metadata['source']}\n\n"
    return ret_str


def download_html(url: str) -> bytes:
    headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0",
               "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
               "Accept-Language": "en-US,en;q=0.5"}

    response = requests.get(url, headers=headers, verify=True, timeout=8)
    response.raise_for_status()

    content_type = response.headers.get("Content-Type", "")
    if not content_type.startswith("text/html"):
        raise ValueError(f"Expected content type text/html. Got {content_type}.")
    return response.content


def html_to_plaintext_doc(html_text: str | bytes, url: str) -> Document:
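    """Strip scripts, styles and markup from an HTML page and wrap the plain text in a Document."""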
    with warnings.catch_warnings(action="ignore"):
        soup = BeautifulSoup(html_text, features="lxml")
    for script in soup(["script", "style"]):
        script.extract()

    strings = '\n'.join([s.strip() for s in soup.stripped_strings])
    webpage_document = Document(page_content=strings, metadata={"source": url})
    return webpage_document
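

# Illustrative usage (a minimal sketch; assumes the package's relative imports resolve and that a
# CUDA device is available, otherwise pass device="cpu"):
#
#     compressor = LangchainCompressor(device="cpu", keyword_retriever="bm25")
#     retrieval = compressor.retrieve_documents("example query", ["https://example.com"])
#     try:
#         while True:
#             print(next(retrieval))          # progress messages
#     except StopIteration as stop:
#         print(docs_to_pretty_str(stop.value or []))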