Spaces:
Runtime error
Runtime error
import torch | |
from bert.tokenize import extract_inputs_masks, tokenize_encode_corpus | |
from torch.utils.data import TensorDataset, DataLoader | |
def predict(samples, tokenizer, scaler, model, device, max_len, batch_size,
            return_scaled=False):
    """Run batched inference over ``samples`` and return flat predictions.

    Args:
        samples: Raw inputs handed to ``tokenize_encode_corpus``.
        tokenizer: Tokenizer passed through to ``tokenize_encode_corpus``.
        scaler: Fitted object exposing ``inverse_transform`` whose result
            supports ``.reshape`` (e.g. a NumPy array). Used to map model
            outputs back to the original target scale. NOTE(review): assumed
            to have been fit on a single target column (one value per
            sample) -- confirm against the training code.
        model: Module called as ``model(input_ids, attention_mask)``; it is
            switched to eval mode and not switched back.
        device: Device each batch is moved to before the forward pass.
        max_len: Maximum sequence length used during tokenization.
        batch_size: Number of samples per forward pass.
        return_scaled: If True, return the raw model outputs without applying
            ``scaler.inverse_transform``.

    Returns:
        A flat list of floats: every model output value, in dataset order.
        The values are inverse-transformed unless ``return_scaled`` is True.
        An empty list is returned when there is nothing to predict.
    """
    model.eval()

    encoded_corpus = tokenize_encode_corpus(tokenizer, samples, max_len)
    input_ids, attention_mask = extract_inputs_masks(encoded_corpus)
    # Keep the full dataset on the CPU and move one batch at a time to
    # ``device`` so a large corpus never has to fit in accelerator memory.
    dataset = TensorDataset(torch.tensor(input_ids),
                            torch.tensor(attention_mask))
    dataloader = DataLoader(dataset, batch_size)

    output = []
    with torch.no_grad():
        for batch_inputs, batch_masks in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_masks = batch_masks.to(device)
            predictions = model(batch_inputs, batch_masks)
            # reshape (unlike view) also handles non-contiguous outputs.
            output.extend(predictions.reshape(-1).tolist())

    # Nothing to inverse-transform: scalers reject zero-sample input.
    if return_scaled or not output:
        return output

    # Scalers expect (n_samples, n_features): give each prediction its own
    # row instead of packing all predictions into a single row.
    unscaled = scaler.inverse_transform([[value] for value in output])
    return unscaled.reshape(-1).tolist()