|
import torch |
|
from transformers import AutoTokenizer, BigBirdForQuestionAnswering |
|
from datasets import load_dataset |
|
|
|
# Demo: extractive question answering with BigBird on a single long SQuAD v2
# example. BigBird's block-sparse attention lets it consume contexts far longer
# than a standard BERT-style 512-token window.

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base")

# Full SQuAD v2 training split; only one example is used below.
squad_ds = load_dataset("squad_v2", split="train")

# Fixed example index from the original notebook — presumably chosen because its
# context is long enough to exercise BigBird's sparse attention (TODO confirm).
LONG_ARTICLE = squad_ds[81514]["context"]
QUESTION = squad_ds[81514]["question"]
# Was a bare notebook-cell expression (`QUESTION`), which is a no-op in a script.
print(f"Question: {QUESTION}")

# Tokenize question + context as a single pair input for the QA head.
inputs = tokenizer(QUESTION, LONG_ARTICLE, return_tensors="pt")
# Was another bare notebook-cell expression; print so the shape is visible.
print(f"input_ids shape: {list(inputs['input_ids'].shape)}")

# Inference only — disable autograd to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)

# Greedy span selection: independent argmax over start and end logits.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
# NOTE(review): if the end argmax precedes the start argmax this slice is empty
# and the decoded answer is "" — acceptable for a demo, but a robust decoder
# would search for the best valid (start <= end) span.
predict_answer_token_ids = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
predict_answer_token = tokenizer.decode(predict_answer_token_ids)
# The original script computed the answer and then discarded it; emit it.
print(f"Predicted answer: {predict_answer_token}")