from __future__ import annotations

import bisect
import math
import os
import pickle
import re
from abc import abstractmethod
from collections import Counter
from dataclasses import asdict, dataclass
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar

import nltk
import numpy as np
import pytrec_eval
import tqdm

from nlp4web_codebase.ir.data_loaders import Split
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords  # noqa: E402  (imported after the download)

LANGUAGE = "english"
# Tokens are runs of two or more (Unicode) word characters.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    """Lower-case *text* and split it into word tokens of length >= 2."""
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    """Identity placeholder: lemmatization is deliberately skipped for simplicity."""
    return words


def simple_tokenize(text: str) -> List[str]:
    """Split *text* into lower-cased words, drop English stopwords, then lemmatize.

    Duplicate tokens are kept, so the result can be fed to ``Counter`` for TFs.
    """
    words = word_splitting(text)
    tokenized = [word for word in words if word not in stopwords]
    return lemmatization(tokenized)


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i]: docid (int) of the i-th posting
    tweight_postings: List[float]  # tweight_postings[i]: term weight of the i-th posting


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to ``<output_dir>/index.pkl``, creating the directory."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by ``save`` from ``<saved_dir>/index.pkl``.

        Security: ``pickle.load`` can execute arbitrary code; only load index
        files you produced yourself.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


# The output of the counting function:
@dataclass
class Counting:
    posting_lists: List[PostingList]  # tid -> posting list (weights are raw TFs)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length (number of kept tokens)
    avgdl: float  # mean of dls (0.0 for an empty collection)
    nterms: int  # total number of kept tokens over the collection
    doc_texts: Optional[List[str]] = None  # docid -> text, or None if store_raw is False


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over *documents*.

    Documents whose ``collection_id`` was already seen are skipped. Docids are
    assigned in increasing order of first appearance, so every posting list's
    ``docid_postings`` is ascending. ``tweight_postings`` holds raw term frequencies.

    Args:
        documents: Documents exposing ``collection_id`` and ``text``.
        tokenize_fn: Turns a document text into a token list.
        store_raw: Keep the raw texts in ``doc_texts`` (otherwise it is None).
        ndocs: Expected number of documents; only used as the progress-bar total.
        show_progress_bar: Whether to display the tqdm progress bar.

    Returns:
        A ``Counting`` with posting lists, vocabulary, DFs, lengths and totals.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        docid = len(collection_ids)
        collection_ids.append(doc.collection_id)
        cid2docid[doc.collection_id] = docid
        tok2tf = Counter(tokenize_fn(doc.text))
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok)
            if tid is None:
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Every document containing the term counts once, including the first one.
            dfs[tid] += 1
        if doc_texts is not None:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # An empty collection has no lengths to average; avoid dividing by zero.
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )


sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))


@dataclass
class BM25Index(InvertedIndex):
    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenizer used both for indexing and for queries."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
        show_progress_bar: bool = True,
    ) -> None:
        """Replace the raw TFs in *posting_lists* (in place) by BM25 term weights.

        weight = regularized_tf(tf, dl) * idf(df), see ``calc_regularized_tf``
        and ``calc_idf``.
        """
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(
                posting_lists, desc="Regularizing TFs", disable=not show_progress_bar
            )
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            # The right-hand side is fully built before the slice assignment runs.
            posting_list.tweight_postings[:] = [
                BM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                * idf
                for docid, tf in zip(
                    posting_list.docid_postings, posting_list.tweight_postings
                )
            ]

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 saturated TF: tf / (tf + k1 * (1 - b + b * dl / avgdl))."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 idf: log(1 + (N - df + 0.5) / (df + 0.5)), positive for df <= N."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count *documents*, cache their BM25 weights and return the index.

        If *output_dir* is given, the index is also saved there.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
            show_progress_bar=show_progress_bar,
        )

        # Assembly and save:
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index


bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")


class BaseInvertedIndexRetriever(BaseRetriever):
    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {token: cached weight} for the query tokens occurring in document *cid*.

        Raises KeyError if *cid* is not in the index.
        """
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            postings = posting_list.docid_postings
            # run_counting appends docids in increasing order, so the list is sorted.
            pos = bisect.bisect_left(postings, target_docid)
            if pos < len(postings) and postings[pos] == target_docid:
                term_weights[tok] = posting_list.tweight_postings[pos]
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Sum of the cached weights of the query tokens found in document *cid*."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the *topk* best documents as {collection_id: score}, best first.

        A document's score is the sum of the cached weights of every query token
        it contains; repeated query tokens contribute repeatedly.
        """
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight
        top = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top}


class BM25Retriever(BaseInvertedIndexRetriever):
    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index


def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Mean over queries of pytrec_eval's ``map_cut_10`` for *rankings* on *split*.

    *rankings* maps query_id -> {collection_id: score}.
    """
    metric = "map_cut_10"
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    qps = evaluator.evaluate(rankings)
    return float(np.mean([qp[metric] for qp in qps.values()]))


# The dataset (`sciq`), `counting` and the default BM25 index saved to
# "output/bm25_index" were already produced above; they are not rebuilt here.

plots_b: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": [],
}
plots_k1: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": [],
}

## YOUR_CODE_STARTS_HERE
# Two steps should be involved:
# Step 1. Fix k1 value to the default one 0.9,
# go through all the candidate b values (0, 0.1, ..., 1.0),
# and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
# Step 2. Fix b to the best one in step 1. and do the same for k1.
# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
#   (dev queries can be obtained via sciq.get_split_queries(Split.dev))


def _dev_map_for_params(k1: float, b: float, index_dir: str = "output/bm25_index") -> float:
    """Build a BM25 index with (*k1*, *b*), save it to *index_dir*, return dev MAP.

    The score is ``evaluate_map`` over the top-10 rankings of all dev queries.
    """
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        show_progress_bar=False,
        k1=k1,
        b=b,
    )
    index.save(index_dir)
    retriever = BM25Retriever(index_dir=index_dir)
    rankings = {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(Split.dev)
    }
    return evaluate_map(rankings, split=Split.dev)


# Step 1: keep the default k1 = 0.9 and sweep b.
plots_b["Y"].extend(_dev_map_for_params(k1=0.9, b=b) for b in plots_b["X"])
best_b = plots_b["X"][int(np.argmax(plots_b["Y"]))]

# Step 2: keep the best b and sweep k1.
plots_k1["Y"].extend(_dev_map_for_params(k1=k1, b=best_b) for k1 in plots_k1["X"])
best_k1 = plots_k1["X"][int(np.argmax(plots_k1["Y"]))]
# NOTE: "output/bm25_index" now holds the LAST sweep configuration (k1=1.0), not
# the best one; the tuned index is the CSC index saved below.
## YOUR_CODE_ENDS_HERE

from scipy.sparse import csc_matrix

TCSC = TypeVar("TCSC", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    posting_lists_matrix: csc_matrix  # [docid, tid] -> term weight (see CSCBM25Index)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to ``<output_dir>/index.pkl``, creating the directory."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[TCSC], saved_dir: str) -> TCSC:
        """Load an index previously written by ``save`` from ``<saved_dir>/index.pkl``.

        Security: ``pickle.load`` can execute arbitrary code; only load index
        files you produced yourself.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


@dataclass
class CSCBM25Index(CSCInvertedIndex):
    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenizer used both for indexing and for queries."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
        show_progress_bar: bool = True,
    ) -> csc_matrix:
        """Compute BM25 term weights into a (total_docs x n_terms) float32 CSC matrix.

        Entry [docid, tid] = regularized_tf(tf, dl) * idf(df). The posting lists
        themselves are not modified.
        """
        ## YOUR_CODE_STARTS_HERE
        N = total_docs
        weights: List[float] = []
        rows: List[int] = []  # docids
        cols: List[int] = []  # tids
        for tid, posting_list in enumerate(
            tqdm.tqdm(
                posting_lists, desc="Regularizing TFs", disable=not show_progress_bar
            )
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            for docid, tf in zip(posting_list.docid_postings, posting_list.tweight_postings):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                weights.append(regularized_tf * idf)
                rows.append(docid)
                cols.append(tid)
        # Column-compressed so that one term's postings (a column) are cheap to slice.
        return csc_matrix(
            (weights, (rows, cols)),
            shape=(total_docs, len(posting_lists)),
            dtype=np.float32,
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 saturated TF: tf / (tf + k1 * (1 - b + b * dl / avgdl))."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 idf: log(1 + (N - df + 0.5) / (df + 0.5)), positive for df <= N."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count *documents*, build the BM25 weight matrix and return the index.

        If *output_dir* is given, the index is also saved there.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
            show_progress_bar=show_progress_bar,
        )

        # Assembly and save:
        index = cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index


csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")


class BaseCSCInvertedIndexRetriever(BaseRetriever):
    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {token: weight} for the query tokens stored in document *cid*'s row.

        Raises KeyError if *cid* is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid = self.index.cid2docid[cid]
        # getrow yields a 1 x n_terms sparse row whose `indices` are tids.
        doc_row = self.index.posting_lists_matrix.getrow(docid)
        tid2weight = dict(zip(doc_row.indices.tolist(), doc_row.data))
        term_weights = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is not None and tid in tid2weight:
                term_weights[tok] = tid2weight[tid]
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Sum of the cached weights of the query tokens found in document *cid*."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the *topk* best documents as {collection_id: score}, best first.

        A document's score is the sum of the weights of every query token it
        contains; repeated query tokens contribute repeatedly.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            # One CSC column = the term's postings; its `indices` are docids.
            term_column = self.index.posting_lists_matrix.getcol(tid)
            for docid, tweight in zip(term_column.indices, term_column.data):
                docid2score[docid] = docid2score.get(docid, 0) + tweight
        top = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top}
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


import functools
from typing import TypedDict

import gradio as gr


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE


@functools.lru_cache(maxsize=1)
def _load_demo_retriever() -> CSCBM25Retriever:
    """Load the tuned (best_k1, best_b) index once and reuse it for every query."""
    return CSCBM25Retriever(index_dir="output/csc_bm25_index")


def search(query: str) -> List[Hit]:
    """Return the top-10 hits for *query* with their collection id, score and text."""
    retriever = _load_demo_retriever()
    index = retriever.index
    hits: List[Hit] = []
    for cid, score in retriever.retrieve(query).items():
        hits.append(
            {
                "cid": cid,
                "score": float(score),
                "text": index.doc_texts[index.cid2docid[cid]],
            }
        )
    return hits


demo = gr.Interface(fn=search, inputs=["text"], outputs=gr.Textbox())
## YOUR_CODE_ENDS_HERE
demo.launch(share=True)