import pickle
from typing import Optional

import torch
from torch import Tensor
from transformers import AutoTokenizer, DistilBertForSequenceClassification
from transformers import BatchEncoding, PreTrainedTokenizerBase

# Load the fine-tuned DistilBERT classifier, its tokenizer, and the pickled label encoder.
model = DistilBertForSequenceClassification.from_pretrained("DistillMDPI1/DistillMDPI1/saved_model")
tokenizer = AutoTokenizer.from_pretrained("DistillMDPI1/DistillMDPI1/saved_tokenizer")

with open("DistillMDPI1/DistillMDPI1/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)

# Class index -> (topic label, style keyword, hex colour).
class_labels = {
    16: ('vehicles', 'info', '#4f9ef8'),
    10: ('environments', 'success', '#0cbc87'),
    9: ('energies', 'danger', '#d6293e'),
    0: ('Physics', 'primary', '#0f6fec'),
    13: ('robotics', 'moss', '#B1E5F2'),
    3: ('agriculture', 'agri', '#a8c686'),
    11: ('ML', 'yellow', '#ffc107'),
    8: ('economies', 'warning', '#f7c32e'),
    15: ('technologies', 'vanila', '#FDF0D5'),
    12: ('mathematics', 'coffe', '#7f5539'),
    14: ('sports', 'orange', '#fd7e14'),
    4: ('AI', 'cyan', '#0dcaf0'),
    6: ('Innovation', 'rosy', '#BF98A0'),
    5: ('Science', 'picton', '#5fa8d3'),
    1: ('Societies', 'purple', '#6f42c1'),
    2: ('administration', 'pink', '#d63384'),
    7: ('biology', 'cambridge', '#88aa99'),
}


def predict_class(text):
    # Split the document into chunks of up to 510 tokens (leaving room for [CLS] and [SEP]),
    # with no overlap (stride == chunk size) and the tokenized text capped at 2550 tokens.
    # The single string is wrapped in a list because transform_list_of_texts expects list[str].
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)

    input_ids_tensor = inputs["input_ids"][0]
    attention_mask_tensor = inputs["attention_mask"][0]

    # Each chunk is scored as one row of the batch.
    with torch.no_grad():
        outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

    # outputs.logits has shape (num_chunks, num_labels); average the per-chunk
    # probabilities so every chunk of a long document contributes to the prediction.
    probabilities = torch.softmax(outputs.logits, dim=1).mean(dim=0)

    predicted_class_index = torch.argmax(probabilities).item()
    predicted_class = class_labels[predicted_class_index]

    # Percentage score for every class, sorted from most to least likely.
    sorted_percentages = {class_labels[idx]: probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(sorted_percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages


def transform_list_of_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    """Transforms each text into chunked model inputs and collects them in a single BatchEncoding."""
    model_inputs = [
        transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)


def transform_single_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask


def tokenize_whole_text(text: str, tokenizer: PreTrainedTokenizerBase) -> BatchEncoding:
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(
    text: str, tokenizer: PreTrainedTokenizerBase, maximal_text_length: int
) -> BatchEncoding:
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens


def split_tokens_into_smaller_chunks(
    tokens: BatchEncoding,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
) -> tuple[list[Tensor], list[Tensor]]:
    """Splits tokens into overlapping chunks with given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks
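

# Illustrative sketch only: demo_chunking is a hypothetical helper (not used elsewhere in
# this module) showing how split_tokens_into_smaller_chunks produces overlapping windows
# when stride is smaller than chunk_size. It reuses the tokenizer loaded above.
def demo_chunking() -> None:
    tokens = tokenize_whole_text("a short sentence used only to illustrate overlapping chunking", tokenizer)
    id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size=5, stride=3, minimal_chunk_length=2)
    for chunk in id_chunks:
        # With chunk_size=5 and stride=3, consecutive chunks share two token ids.
        print(chunk.tolist())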


def add_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds special CLS token (token id = 101) at the beginning.
    Adds SEP token (token id = 102) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean).
    """
    for i in range(len(input_id_chunks)):
        # Tensor([...]) yields float values here; stack_tokens_from_all_chunks casts the
        # chunks back to integer dtypes before they reach the model.
        input_id_chunks[i] = torch.cat([Tensor([101]), input_id_chunks[i], Tensor([102])])
        mask_chunks[i] = torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])


def add_padding_tokens(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """Adds padding tokens (token id = 0) at the end to make sure that all chunks have exactly 512 tokens."""
    for i in range(len(input_id_chunks)):
        pad_len = 512 - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([input_id_chunks[i], Tensor([0] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], Tensor([0] * pad_len)])


def stack_tokens_from_all_chunks(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)
    return input_ids.long(), attention_mask.int()
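

# Optional sanity check: check_chunk_shapes is a hypothetical helper (not called above)
# that asserts the contract the rest of the pipeline relies on, assuming the fixed
# 512-token chunk length used in add_padding_tokens.
def check_chunk_shapes(input_ids: Tensor, attention_mask: Tensor) -> None:
    assert input_ids.shape == attention_mask.shape
    assert input_ids.shape[1] == 512  # every chunk is padded to exactly 512 tokens
    assert input_ids.dtype == torch.int64  # token ids cast via .long()
    assert attention_mask.dtype == torch.int32  # mask cast via .int()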


def split_overlapping(tensor: Tensor, chunk_size: int, stride: int, minimal_chunk_length: int) -> list[Tensor]:
    """Helper function for dividing 1-dimensional tensors into overlapping chunks."""
    result = [tensor[i : i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        # Discard trailing chunks shorter than minimal_chunk_length, but never the only chunk.
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result
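

# Illustrative usage only: the sample abstract below is made up, and this block simply
# exercises predict_class as defined above when the module is run as a script.
if __name__ == "__main__":
    sample_text = (
        "Reinforcement learning controllers for autonomous drones must balance "
        "energy consumption against flight stability in uncertain outdoor environments."
    )
    label, percentages = predict_class(sample_text)
    print(f"Predicted class: {label[0]}")
    for (name, _style, _colour), pct in list(percentages.items())[:5]:
        print(f"{name}: {pct:.2f}%")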