import gradio as gr
import joblib
import pandas as pd
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, TypedDict

from datasets import load_dataset


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return list(filter(lambda query: query.query_id in qids, self.queries))
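

# Note (illustrative): get_qrels_dict(Split.train) returns a nested mapping of the form
# {query_id: {collection_id: relevance}}, e.g. {"train-42": {"train-42": 1}}, which is
# the usual input format for IR evaluation tooling; the example values are made up here.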
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: a question duplicated across splits always carries the same
    # support passage and the same correct answer:
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build the corpus, queries and qrels, skipping empty supports and duplicates:
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if support in support2id:
                continue
            else:
                support2id[support] = example_id
            if question in question2id:
                continue
            else:
                question2id[question] = example_id
            doc = {"collection_id": example_id, "text": support}
            query = {"query_id": example_id, "text": row["question"]}
            qrel = {
                "query_id": example_id,
                "collection_id": example_id,
                "relevance": 1,
                "answer": row["correct_answer"],
            }
            corpus.append(Document(**doc))
            queries.append(Query(**query))
            split2qrels[split].append(QRel(**qrel))

    # Assemble and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
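

# Note: joblib.Memory(".cache") memoizes load_sciq on disk, so only the first call
# downloads and processes the dataset; later calls (including per-request calls in
# the Gradio app below) load the cached result instead.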
if __name__ == "__main__":
    # python -m nlp4web_codebase.ir.data_loaders.sciq
    import ujson
    import time

    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(ujson.dumps(dataset.get_stats(), indent=4))
    # ________________________________________________________________________________
    # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
    # load_sciq(verbose=True)
    # |raw_train| 11679
    # |raw_dev| 1000
    # |raw_test| 1000
    # ________________________________________________________load_sciq - 7.3s, 0.1min
    # Loading costs: 7.260092735290527s
    # {
    #     "|corpus|": 12160,
    #     "|queries|": 12160,
    #     "|qrels-train|": 10409,
    #     "|qrels-dev|": 875,
    #     "|qrels-test|": 876
    # }

class Hit(TypedDict):
    cid: str
    score: float
    text: str
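

# A Hit is a plain dict, e.g. {"cid": "train-7", "score": 12.3, "text": "..."}
# (illustrative values only).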
## YOUR_CODE_STARTS_HERE
def search(query: str) -> List[Hit]:
    # NOTE: BM25Index and BM25Retriever are assumed to be provided elsewhere
    # (e.g. defined earlier in the file or imported from the course codebase);
    # they are not shown in this snippet. Building the index inside search() is
    # expensive; in practice you would build and save it once at startup and
    # only retrieve per request.
    sciq = load_sciq()
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        show_progress_bar=True,
    )
    bm25_index.save("output/bm25_index")
    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    results = bm25_retriever.retrieve(query=query)
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})
    return hits
## YOUR_CODE_ENDS_HERE
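
# Example (illustrative): search("What does photosynthesis produce?") returns a list of
# Hit dicts whose "text" fields are SciQ support passages, ranked by BM25 score.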
demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)  # Assign your gradio demo to this variable
return_type = List[Hit]

demo.launch()