Spaces:
Runtime error
Runtime error
import json | |
from typing import List, Literal, Protocol, Tuple, TypedDict, Union | |
from pyserini.analysis import get_lucene_analyzer | |
from pyserini.index import IndexReader | |
from pyserini.search import DenseSearchResult, JLuceneSearcherResult | |
from pyserini.search.faiss.__main__ import init_query_encoder | |
from pyserini.search.faiss import FaissSearcher | |
from pyserini.search.hybrid import HybridSearcher | |
from pyserini.search.lucene import LuceneSearcher | |
# Closed set of query-encoder class names accepted by the `encoder_class`
# argument of `init_searcher_and_reader`.
EncoderClass = Literal[
    "dkrr",
    "dpr",
    "tct_colbert",
    "ance",
    "sentence",
    "contriever",
    "auto",
]
class AnalyzerArgs(TypedDict, total=False):
    """
    Keyword arguments forwarded verbatim to `get_lucene_analyzer`.

    The dictionary is unpacked with `**analyzer_args`, so the key names must
    match the factory's parameter names exactly (hence the camelCase
    `huggingFaceTokenizer`). Every key is optional: an omitted key simply
    leaves the corresponding factory default in place.
    """
    # NOTE(review): value semantics below are inferred from the key names;
    # confirm against the pyserini `get_lucene_analyzer` documentation.
    language: str  # analyzer language
    stemming: bool  # whether stemming is enabled
    stemmer: str  # name of the stemmer to use
    stopwords: bool  # whether stopword handling is enabled
    huggingFaceTokenizer: str  # Hugging Face tokenizer name/path
class _SearchResultRequired(TypedDict):
    """Keys that every search result carries."""
    docid: str  # document identifier
    text: str  # document contents
    score: float  # retrieval score reported by the searcher


class SearchResult(_SearchResultRequired, total=False):
    """
    A single retrieved document in plain-dictionary form.

    `docid`, `text` and `score` are always present. `language` is optional,
    because results may be built without it; declaring it as required would
    make such results invalid for static type checkers.
    """
    language: str  # document language, when known
class Searcher(Protocol):
    """Structural type for any object that exposes a `search` method."""

    def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
        """Return the list of hits for `query`; extra options arrive as keyword arguments."""
def init_searcher_and_reader(
    sparse_index_path: str = None,
    bm25_k1: float = None,
    bm25_b: float = None,
    analyzer_args: AnalyzerArgs = None,
    dense_index_path: str = None,
    encoder_name_or_path: str = None,
    encoder_class: EncoderClass = None,
    tokenizer_name: str = None,
    device: str = None,
    prefix: str = None
) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], Union[IndexReader, None]]:
    """
    Initialize and return an appropriate searcher together with an index reader.

    Which searcher is built depends on the index paths that are supplied:
    sparse only -> `LuceneSearcher`, dense only -> `FaissSearcher`,
    both -> `HybridSearcher`.

    Parameters
    ----------
    sparse_index_path: str
        Path to sparse index
    bm25_k1: float
        BM25 k1 parameter. Only applied when `bm25_b` is given as well.
    bm25_b: float
        BM25 b parameter. Only applied when `bm25_k1` is given as well.
    analyzer_args: AnalyzerArgs
        Keyword arguments forwarded to `get_lucene_analyzer`; when given, the
        resulting analyzer is installed on the sparse searcher.
    dense_index_path: str
        Path to dense index
    encoder_name_or_path: str
        Path to query encoder checkpoint or encoder name
    encoder_class: str
        Query encoder class to use. If None, infer from `encoder`
    tokenizer_name: str
        Tokenizer name or path
    device: str
        Device to load Query encoder on.
    prefix: str
        Query prefix if exists

    Returns
    -------
    Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    Reader: IndexReader | None
        Reader over the sparse index. It is only created when both a dense
        and a sparse index are supplied; otherwise it is None.

    Raises
    ------
    ValueError
        If neither `sparse_index_path` nor `dense_index_path` is provided.
    """
    if not sparse_index_path and not dense_index_path:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            "At least one of `sparse_index_path` or `dense_index_path` must be provided"
        )

    reader = None
    ssearcher = None
    if sparse_index_path:
        ssearcher = LuceneSearcher(sparse_index_path)
        if analyzer_args:
            ssearcher.set_analyzer(get_lucene_analyzer(**analyzer_args))
        # Compare against None rather than truthiness so that an explicit 0.0
        # is treated as a value and not as "unset".
        if bm25_k1 is not None and bm25_b is not None:
            ssearcher.set_bm25(bm25_k1, bm25_b)

    if not dense_index_path:
        return ssearcher, reader

    encoder = init_query_encoder(
        encoder=encoder_name_or_path,
        encoder_class=encoder_class,
        tokenizer_name=tokenizer_name,
        topics_name=None,
        encoded_queries=None,
        device=device,
        prefix=prefix
    )
    if sparse_index_path:
        # The reader is opened over the sparse index, so it can only exist when
        # that path is given; never hand `None` to IndexReader.
        reader = IndexReader(sparse_index_path)
    dsearcher = FaissSearcher(dense_index_path, encoder)

    if ssearcher is not None:
        return HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher), reader
    return dsearcher, reader
def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]:
    """
    Run `query` against `searcher` and return the hits as plain dictionaries.

    Parameters:
    -----------
    searcher: FaissSearcher | HybridSearcher | LuceneSearcher
        A sparse, dense or hybrid searcher
    reader: IndexReader
        Index reader used to fetch the stored document for hits that only
        carry a docid (`DenseSearchResult`). Not consulted for
        `JLuceneSearcherResult` hits.
    query: str
        Query for which to retrieve results
    num_results: int
        Maximum number of results to retrieve

    Returns:
    --------
    List[SearchResult]:
        One entry per hit, in the order returned by the searcher, holding the
        `id` and `contents` fields of the stored JSON document and the hit score.

    Raises:
    -------
    TypeError
        If the searcher returns a hit of an unsupported type.
    """
    def _load_document(hit: Union[DenseSearchResult, JLuceneSearcherResult]) -> dict:
        # Lucene hits carry the raw JSON document themselves.
        if isinstance(hit, JLuceneSearcherResult):
            return json.loads(hit.raw)
        # Dense hits only carry a docid: get the document from the sparse
        # index using the index reader.
        if isinstance(hit, DenseSearchResult):
            return json.loads(reader.doc(hit.docid).raw())
        # Be explicit instead of returning None and failing later with an
        # opaque "'NoneType' object is not subscriptable".
        raise TypeError(f"Unsupported search result type: {type(hit).__name__}")

    def _to_result(hit: Union[DenseSearchResult, JLuceneSearcherResult]) -> SearchResult:
        document = _load_document(hit)
        return SearchResult(
            docid=document["id"],
            text=document["contents"],
            score=hit.score
        )

    return [_to_result(hit) for hit in searcher.search(query, k=num_results)]