# cakiki's picture
# Upload folder using huggingface_hub
# aeb12b8
import json
from typing import List, Literal, Protocol, Tuple, TypedDict, Union
from pyserini.analysis import get_lucene_analyzer
from pyserini.index import IndexReader
from pyserini.search import DenseSearchResult, JLuceneSearcherResult
from pyserini.search.faiss.__main__ import init_query_encoder
from pyserini.search.faiss import FaissSearcher
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneSearcher
# Accepted values for the `encoder_class` argument of `init_searcher_and_reader`,
# which passes the value on to pyserini's `init_query_encoder`.
EncoderClass = Literal[
    "dkrr",
    "dpr",
    "tct_colbert",
    "ance",
    "sentence",
    "contriever",
    "auto",
]
class AnalyzerArgs(TypedDict):
    """
    Keyword arguments for building a Lucene analyzer.

    `init_searcher_and_reader` unpacks a value of this type directly into
    `pyserini.analysis.get_lucene_analyzer(**analyzer_args)`, so the keys
    must match that function's parameter names exactly.
    """

    # NOTE(review): the per-key meanings below are taken from the key names;
    # confirm against the pyserini `get_lucene_analyzer` documentation.
    language: str  # presumably the language the analyzer is built for
    stemming: bool  # presumably toggles stemming
    stemmer: str  # presumably the name of the stemmer to use
    stopwords: bool  # presumably toggles stopword handling
    huggingFaceTokenizer: str  # presumably a Hugging Face tokenizer name or path
class SearchResult(TypedDict):
    """
    A single retrieved document in plain-dictionary form.

    This is the element type of the list returned by `_search`.
    """

    # NOTE(review): `_search` fills `docid`, `text` and `score` only; it never
    # sets `language`. Confirm whether that key should be optional or populated.
    docid: str  # document identifier (the "id" field of the stored raw JSON)
    text: str  # document text (the "contents" field of the stored raw JSON)
    score: float  # retrieval score reported by the searcher for this hit
    language: str
class Searcher(Protocol):
    """
    Structural type for anything that can be passed to `_search`.

    Any object exposing a compatible `search` method satisfies this protocol;
    no inheritance is required.
    """

    def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        """
        Retrieve hits for `query`.

        `_search` calls this as `search(query, k=num_results)`, so
        implementations need to accept a `k` keyword argument.
        """
        ...
def init_searcher_and_reader(
    sparse_index_path: Union[str, None] = None,
    bm25_k1: Union[float, None] = None,
    bm25_b: Union[float, None] = None,
    analyzer_args: Union[AnalyzerArgs, None] = None,
    dense_index_path: Union[str, None] = None,
    encoder_name_or_path: Union[str, None] = None,
    encoder_class: Union[EncoderClass, None] = None,
    tokenizer_name: Union[str, None] = None,
    device: Union[str, None] = None,
    prefix: Union[str, None] = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], Union[IndexReader, None]]:
    """
    Initialize and return an appropriate searcher together with an index reader

    Which searcher is built depends on which index paths are given:
    sparse only -> LuceneSearcher, dense only -> FaissSearcher,
    both -> HybridSearcher combining the two.

    Parameters
    ----------
    sparse_index_path: str
        Path to sparse index
    bm25_k1: float
        BM25 k1 parameter. Applied only when `bm25_b` is given as well
    bm25_b: float
        BM25 b parameter. Applied only when `bm25_k1` is given as well
    analyzer_args: AnalyzerArgs
        Keyword arguments forwarded to `get_lucene_analyzer`; when given,
        the resulting analyzer is set on the sparse searcher
    dense_index_path: str
        Path to dense index
    encoder_name_or_path: str
        Path to query encoder checkpoint or encoder name
    encoder_class: str
        Query encoder class to use. If None, infer from `encoder_name_or_path`
    tokenizer_name: str
        Tokenizer name or path
    device: str
        Device to load Query encoder on.
    prefix: str
        Query prefix if exists

    Returns
    -------
    Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    Reader: IndexReader | None
        Reader over the sparse index, created only when both a dense and a
        sparse index are given; None otherwise

    Raises
    ------
    ValueError
        If neither `sparse_index_path` nor `dense_index_path` is given
    """
    # Without any index there is nothing to build; fail with a clear message
    # instead of hitting an unbound local further down.
    if not sparse_index_path and not dense_index_path:
        raise ValueError("At least one of `sparse_index_path` or `dense_index_path` must be provided to build a searcher.")

    ssearcher = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        if analyzer_args:
            analyzer = get_lucene_analyzer(**analyzer_args)
            ssearcher.set_analyzer(analyzer)
        # Compare against None rather than truthiness so that an explicit 0.0
        # (e.g. b=0 to switch off length normalisation) is not ignored.
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)

    if not dense_index_path:
        # Sparse-only setup: no reader is created.
        return ssearcher, None

    encoder = init_query_encoder(
        encoder=encoder_name_or_path,
        encoder_class=encoder_class,
        tokenizer_name=tokenizer_name,
        topics_name=None,
        encoded_queries=None,
        device=device,
        prefix=prefix
    )
    # The reader is opened on the sparse index, so it can only exist when a
    # sparse index path was supplied.
    reader = IndexReader(sparse_index_path) if sparse_index_path else None
    dsearcher = FaissSearcher(dense_index_path, encoder)
    if ssearcher is None:
        return dsearcher, reader

    hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
    return hsearcher, reader
def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]:
    """
    Run `query` against `searcher` and return the hits as plain dictionaries

    Parameters:
    -----------
    searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    reader: IndexReader
        Reader used to look up the stored document of a `DenseSearchResult`
        by its docid. Not consulted for `JLuceneSearcherResult` hits
    query: str
        Query for which to retrieve results
    num_results: int
        Maximum number of results to retrieve

    Returns:
    --------
    List[SearchResult]:
        One entry per hit, in the order returned by the searcher, holding
        the document's "id" and "contents" fields and the hit's score

    Raises:
    -------
    ValueError
        If a `DenseSearchResult` has to be resolved but `reader` is None
    TypeError
        If the searcher returns a hit of an unsupported type
    """

    def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]):
        # Each hit is turned into the parsed JSON of its stored raw document.
        if isinstance(r, JLuceneSearcherResult):
            return json.loads(r.raw)
        if isinstance(r, DenseSearchResult):
            if reader is None:
                raise ValueError("An IndexReader is required to resolve the documents of dense search results.")
            # Get document from sparse_index using index reader
            return json.loads(reader.doc(r.docid).raw())
        # Fail loudly rather than returning None and breaking on subscripting.
        raise TypeError(f"Unsupported search result type: {type(r).__name__}")

    search_results = searcher.search(query, k=num_results)
    documents = map(_get_dict, search_results)
    # NOTE(review): the `language` key of SearchResult is not populated here.
    all_results = [
        SearchResult(
            docid=document["id"],
            text=document["contents"],
            score=hit.score
        ) for hit, document in zip(search_results, documents)
    ]
    return all_results