|
import gradio as gr |
|
from typing import TypedDict |
|
from typing import Dict, List |
|
from datasets import load_dataset |
|
import joblib |
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import Dict, List |
|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
|
|
@dataclass

class Document:
    """A single retrievable passage of the corpus."""

    collection_id: str  # unique id of this passage within the corpus

    text: str  # the passage text (SciQ "support" field)
|
|
|
|
|
@dataclass

class Query:
    """A search query (a SciQ question)."""

    query_id: str  # unique id of this query

    text: str  # the question text
|
|
|
|
|
@dataclass

class QRel:
    """A relevance judgement linking a query to a corpus passage."""

    query_id: str  # id of the judged query

    collection_id: str  # id of the judged passage

    relevance: int  # relevance grade (1 for the gold passage in load_sciq)

    answer: Optional[str] = None  # gold answer text, when available (SciQ "correct_answer")
|
|
|
|
|
|
|
class Split(str, Enum):
    """Dataset split names; str mixin lets values be used directly as strings."""

    train = "train"

    dev = "dev"

    test = "test"
|
|
|
|
|
@dataclass
class IRDataset:
    """An IR dataset: a passage corpus, a query set, and per-split relevance judgements."""

    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return size counters for the corpus, the queries, and each qrels split."""
        counters = {
            "|corpus|": len(self.corpus),
            "|queries|": len(self.queries),
        }
        counters.update(
            {f"|qrels-{split}|": len(qrels) for split, qrels in self.split2qrels.items()}
        )
        return counters

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return the judgements of *split* as {query_id: {collection_id: relevance}}."""
        nested: Dict[str, Dict[str, int]] = {}
        for judgement in self.split2qrels[split]:
            per_query = nested.setdefault(judgement.query_id, {})
            per_query[judgement.collection_id] = judgement.relevance
        return nested

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return only the queries that carry at least one judgement in *split*."""
        judged_ids = {judgement.query_id for judgement in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in judged_ids]
|
|
|
|
|
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    """Load SciQ from the HF Hub and convert it into an :class:`IRDataset`.

    Each kept example contributes one document (its "support" passage), one
    query (its question), and one positive qrel linking the two; examples
    with empty support, duplicate supports, or duplicate questions are
    skipped (first occurrence wins, in train→dev→test order).

    Results are memoized on disk via joblib under ``.cache``.

    Args:
        verbose: when True, print the raw number of rows per split.

    Returns:
        The assembled IRDataset.
    """
    import pandas as pd  # local import: only needed for the sanity check below

    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: within any group of rows sharing a question, the supports
    # and answers must all be distinct.
    # FIX: the original combined the splits with `+`, which on DataFrames is
    # index-aligned element-wise addition (string concat / NaN), not
    # concatenation — the check never ran over the real combined data.
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == len(group)
        assert len(set(group["correct_answer"].tolist())) == len(group)

    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}  # FIX: annotation matched neither keys nor values
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            # Skip examples without a supporting passage.
            # (The original repeated this check twice; once is enough.)
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            # Deduplicate across splits, keeping the first occurrence. NOTE:
            # as in the original, a duplicate support is skipped before the
            # question gets registered, so order of these checks matters.
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split2qrels[split].append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
|
|
|
|
|
if __name__ == "__main__":
    # Quick smoke test: load (or hit the joblib cache for) SciQ, report how
    # long it took, and dump the dataset size statistics as JSON.
    import ujson
    import time

    t0 = time.time()
    sciq = load_sciq(verbose=True)
    elapsed = time.time() - t0
    print(f"Loading costs: {elapsed}s")
    print(ujson.dumps(sciq.get_stats(), indent=4))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Hit(TypedDict):
    """One retrieval result as returned by ``search``."""

    cid: str  # collection_id of the matched document

    score: float  # retrieval score from the BM25 retriever

    text: str  # text of the matched document
|
|
|
|
|
def search(query: str) -> List[Hit]:
    """Run a BM25 search over the SciQ corpus and return all scored hits.

    NOTE(review): the index is rebuilt and re-saved on EVERY call; for a
    long-running app, build it once at startup and reuse the retriever.

    Args:
        query: free-text query string.

    Returns:
        One ``Hit`` per retrieved document (cid, score, text).
    """
    sciq = load_sciq()

    # FIX: derive ndocs from the corpus instead of the hard-coded 12160 so
    # the index stays correct if the corpus size ever changes. (The original
    # also called run_counting() here and discarded its result — removed.)
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        show_progress_bar=True,
    )
    bm25_index.save("output/bm25_index")

    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    results = bm25_retriever.retrieve(query=query)

    # Map each scored collection id back to its document text.
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})

    return hits
|
|
|
|
|
# Gradio UI: one text box feeding `search`, results rendered as JSON.
# FIX: the annotation was Optional[gr.Interface] although the value is
# assigned unconditionally and is never None.
demo: gr.Interface = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)
return_type = List[Hit]  # kept as-is: module-level name, may be read elsewhere
demo.launch()