from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, TypedDict

import gradio as gr
import joblib
import pandas as pd
from datasets import load_dataset


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Nested mapping: query_id -> {collection_id: relevance}."""
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return list(filter(lambda query: query.query_id in qids, self.queries))


@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: records sharing the same question are exact duplicates,
    # i.e. within each group there is only one support and one correct answer:
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build corpus, queries, and qrels, skipping empty supports and duplicates:
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split2qrels[split].append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)


if __name__ == "__main__":
    # python -m nlp4web_codebase.ir.data_loaders.sciq
    import ujson
    import time

    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(ujson.dumps(dataset.get_stats(), indent=4))
    # Expected output:
    # ________________________________________________________________________________
    # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
    # load_sciq(verbose=True)
    # |raw_train| 11679
    # |raw_dev| 1000
    # |raw_test| 1000
    # ________________________________________________________load_sciq - 7.3s, 0.1min
    # Loading costs: 7.260092735290527s
    # {
    #     "|corpus|": 12160,
    #     "|queries|": 12160,
    #     "|qrels-train|": 10409,
    #     "|qrels-dev|": 875,
    #     "|qrels-test|": 876
    # }
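
# A minimal evaluation sketch (not part of the assignment): the nested dict
# returned by get_qrels_dict() is exactly the qrels format expected by the
# third-party package pytrec_eval (assumed installed, e.g. via
# `pip install pytrec_eval`). The dummy run below, which scores every
# relevant document with 1.0, is a hypothetical placeholder for the output
# of a real retriever.
def _evaluation_sketch(split: Split = Split.dev) -> Dict[str, Dict[str, float]]:
    import pytrec_eval

    dataset = load_sciq()
    qrels = dataset.get_qrels_dict(split)  # {query_id: {collection_id: relevance}}
    # Hypothetical run: {query_id: {collection_id: retrieval_score}}
    run = {qid: {cid: 1.0 for cid in rels} for qid, rels in qrels.items()}
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg"})
    return evaluator.evaluate(run)  # per-query {"ndcg": score}
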
class Hit(TypedDict):
    cid: str
    score: float
    text: str


## YOUR_CODE_STARTS_HERE
# BM25Index and BM25Retriever are defined earlier in the exercise and are
# assumed to be in scope here. Load the dataset and build the index once at
# module level, so each search call only runs retrieval instead of re-indexing.
sciq = load_sciq()
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=len(sciq.corpus), show_progress_bar=True
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query=query)
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})
    return hits
## YOUR_CODE_ENDS_HERE

# Assign your gradio demo to this variable:
demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)
return_type = List[Hit]

demo.launch()
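
# A minimal sketch of querying the running demo programmatically, assuming the
# companion package gradio_client is installed and the app is serving on
# Gradio's default local address (both are assumptions; the query string is
# hypothetical). Uncomment and run from a separate process:
#
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860")
# hits = client.predict("Who proposed the theory of evolution?", api_name="/predict")
# print(hits[0]["cid"], hits[0]["text"])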