from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, TypedDict

import gradio as gr
import joblib
import pandas as pd
from datasets import load_dataset


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return [query for query in self.queries if query.query_id in qids]
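
# Minimal usage sketch (assumes a populated `dataset: IRDataset`):
#   qrels = dataset.get_qrels_dict(Split.dev)        # {query_id: {collection_id: relevance}}
#   dev_queries = dataset.get_split_queries(Split.dev)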


memory = joblib.Memory(".cache")


@memory.cache  # cache the built dataset on disk under .cache/
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: records sharing a question also share the same support
    # and answer, so deduplicating later loses no information:
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build:
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if len(question.strip()) == 0:
                continue
            # Skip duplicate supports and questions, keeping the first occurrence:
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            doc = {"collection_id": example_id, "text": support}
            query = {"query_id": example_id, "text": question}
            qrel = {
                "query_id": example_id,
                "collection_id": example_id,
                "relevance": 1,
                "answer": row["correct_answer"],
            }
            corpus.append(Document(**doc))
            queries.append(Query(**query))
            split2qrels[split].append(QRel(**qrel))

    # Assemble and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
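

# Note: because of joblib.Memory, repeated calls to load_sciq() are served
# from the on-disk cache under .cache/ (cf. the "[Memory] Calling ..." log below).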


if __name__ == "__main__":
    # python -m nlp4web_codebase.ir.data_loaders.sciq
    import time

    import ujson

    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(ujson.dumps(dataset.get_stats(), indent=4))
# ________________________________________________________________________________
# [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
# load_sciq(verbose=True)
# |raw_train| 11679
# |raw_dev| 1000
# |raw_test| 1000
# ________________________________________________________load_sciq - 7.3s, 0.1min
# Loading costs: 7.260092735290527s
# {
# "|corpus|": 12160,
# "|queries|": 12160,
# "|qrels-train|": 10409,
# "|qrels-dev|": 875,
# "|qrels-test|": 876
# }
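
# Note: the qrels counts above are smaller than the raw split sizes because
# load_sciq() skips rows with empty support and drops duplicate supports/questions.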


class Hit(TypedDict):
    cid: str
    score: float
    text: str
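
# Each Hit is a plain dict, e.g. (values hypothetical):
#   {"cid": "train-42", "score": 11.8, "text": "Mesophiles grow best in moderate temperature ..."}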


## YOUR_CODE_STARTS_HERE
# BM25Index and BM25Retriever are assumed to be defined earlier in the notebook.
# Build and save the index once at startup so search() only has to retrieve:
sciq = load_sciq()
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    """Return the BM25 hits for `query` over the SciQ corpus."""
    results = bm25_retriever.retrieve(query=query)
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})
    return hits
## YOUR_CODE_ENDS_HERE
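
# Usage sketch (IDs, scores, and text are illustrative, not actual outputs):
#   search("What type of cells secrete insulin?")
#   -> [{"cid": "test-417", "score": 13.2, "text": "..."}, ...]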

demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)  # Assign your gradio demo to this variable
return_type = List[Hit]

demo.launch()