Spaces:
Running
Running
File size: 2,091 Bytes
dd87c4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# retriever.py
import asyncio
from collections.abc import Mapping
from typing import List

from langchain.schema import BaseRetriever
from pydantic import BaseModel
def _interleave(result_lists):
    """Merge several ranked result lists round-robin, keeping each list's order.

    Every list's item at rank 0 is emitted before any item at rank 1, and so
    on. Truncating the merged list therefore keeps a fair share from each
    input list instead of letting the first list fill every slot.

    Args:
        result_lists: A sequence of lists (one per retriever).

    Returns:
        A single flat list containing every item of ``result_lists``.
    """
    merged = []
    longest = max((len(docs) for docs in result_lists), default=0)
    for rank in range(longest):
        for docs in result_lists:
            if rank < len(docs):
                merged.append(docs[rank])
    return merged


class CombinedRetriever(BaseRetriever):
    """
    A retriever that combines multiple retrievers and returns up to K documents.

    Results from the underlying retrievers are merged round-robin by rank, so
    each retriever contributes to the final list. Simply concatenating the
    results and truncating would let the first retriever crowd out all the
    others. This presumes each retriever returns its documents best-first.
    No scores are compared across retrievers.
    """

    # Underlying retrievers whose results are merged.
    retrievers: List[BaseRetriever]
    # Maximum number of documents returned from the merged results.
    k: int = 5

    def _get_relevant_documents(self, query: str):
        """
        Retrieve relevant documents by querying all combined retrievers.

        Args:
            query: The search query string.

        Returns:
            At most ``self.k`` documents, interleaved round-robin across the
            retrievers. Returns an empty list when ``self.k <= 0``.
        """
        if self.k <= 0:
            return []
        per_retriever = [
            retriever.get_relevant_documents(query)
            for retriever in self.retrievers
        ]
        return _interleave(per_retriever)[: self.k]

    async def _aget_relevant_documents(self, query: str):
        """
        Asynchronously retrieve relevant documents by querying all combined retrievers.

        The retrievers are queried concurrently. ``asyncio.gather`` returns
        results in the same order as ``self.retrievers``, so the merge order
        matches the synchronous path.

        Args:
            query: The search query string.

        Returns:
            At most ``self.k`` documents, interleaved round-robin across the
            retrievers. Returns an empty list when ``self.k <= 0``.
        """
        if self.k <= 0:
            return []
        per_retriever = await asyncio.gather(
            *(
                retriever.aget_relevant_documents(query)
                for retriever in self.retrievers
            )
        )
        return _interleave(per_retriever)[: self.k]
def create_combined_retriever(vector_stores, search_kwargs=None):
    """
    Create a CombinedRetriever from multiple vector stores.

    Args:
        vector_stores: Either a mapping whose values are vector stores (the
            keys are ignored) or any iterable of vector stores.
        search_kwargs: Keyword arguments passed to each store's
            ``as_retriever`` (e.g. ``{"k": 3}``). Defaults to ``{"k": 3}``.
            Each retriever receives its own shallow copy, so the dict is never
            shared between retrievers or across calls.

    Returns:
        An instance of CombinedRetriever whose ``k`` is ``search_kwargs["k"]``,
        or 3 when that key is absent.
    """
    # A None default avoids the shared-mutable-default pitfall of a dict literal.
    if search_kwargs is None:
        search_kwargs = {"k": 3}

    stores = (
        vector_stores.values()
        if isinstance(vector_stores, Mapping)
        else vector_stores
    )
    retrievers = [
        vs.as_retriever(search_kwargs=dict(search_kwargs))
        for vs in stores
    ]
    return CombinedRetriever(
        retrievers=retrievers,
        k=search_kwargs.get("k", 3),
    )
|