import pickle
from typing import Optional

import torch
from torch import Tensor
from transformers import AutoTokenizer, DistilBertForSequenceClassification
from transformers import BatchEncoding, PreTrainedTokenizerBase

# Load the fine-tuned model
model = DistilBertForSequenceClassification.from_pretrained("DistillMDPI1/DistillMDPI1/saved_model")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("DistillMDPI1/DistillMDPI1/saved_tokenizer")

# Load the label encoder
with open("DistillMDPI1/DistillMDPI1/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)

# Maps class index -> (label, badge/CSS class, hex color)
class_labels = {
    16: ('vehicles', 'info', '#4f9ef8'),
    10: ('environments', 'success', '#0cbc87'),
    9: ('energies', 'danger', '#d6293e'),
    0: ('Physics', 'primary', '#0f6fec'),
    13: ('robotics', 'moss', '#B1E5F2'),
    3: ('agriculture', 'agri', '#a8c686'),
    11: ('ML', 'yellow', '#ffc107'),
    8: ('economies', 'warning', '#f7c32e'),
    15: ('technologies', 'vanila', '#FDF0D5'),
    12: ('mathematics', 'coffe', '#7f5539'),
    14: ('sports', 'orange', '#fd7e14'),
    4: ('AI', 'cyan', '#0dcaf0'),
    6: ('Innovation', 'rosy', '#BF98A0'),
    5: ('Science', 'picton', '#5fa8d3'),
    1: ('Societies', 'purple', '#6f42c1'),
    2: ('administration', 'pink', '#d63384'),
    7: ('biology', 'cambridge', '#88aa99'),
}


def predict_class(text):
    # Tokenize the text (wrapped in a list, since transform_list_of_texts expects list[str])
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)

    # Extract the chunked tensors for this single text
    input_ids_tensor = inputs["input_ids"][0]
    attention_mask_tensor = inputs["attention_mask"][0]

    # Run the chunks through the model
    with torch.no_grad():
        outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

    # Apply softmax; note that only the first chunk's logits are used here
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # Pick the class with the highest probability
    predicted_class_index = torch.argmax(probabilities).item()
    predicted_class = class_labels[predicted_class_index]

    # Build a percentage dictionary sorted by descending probability
    sorted_percentages = {class_labels[idx]: probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(sorted_percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages


def transform_list_of_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    model_inputs = [
        transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)


def transform_single_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask
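
# Worked example (a sketch, not part of the original code): with the settings
# used in predict_class (chunk_size=510, stride=510, minimal_chunk_length=1,
# maximal_text_length=2550), a text that tokenizes to 1200 tokens is split by
# split_overlapping into chunks of 510, 510 and 180 tokens; after [CLS]/[SEP]
# are added and each chunk is padded, transform_single_text returns input_ids
# and attention_mask tensors of shape (3, 512).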
def tokenize_whole_text(text: str, tokenizer: PreTrainedTokenizerBase) -> BatchEncoding:
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(
    text: str, tokenizer: PreTrainedTokenizerBase, maximal_text_length: int
) -> BatchEncoding:
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens


def split_tokens_into_smaller_chunks(
    tokens: BatchEncoding,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
) -> tuple[list[Tensor], list[Tensor]]:
    """Splits tokens into overlapping chunks with given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks


def add_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds special CLS token (token id = 101) at the beginning.
    Adds SEP token (token id = 102) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean).
    """
    for i in range(len(input_id_chunks)):
        # adding CLS (token id 101) and SEP (token id 102) tokens
        input_id_chunks[i] = torch.cat([Tensor([101]), input_id_chunks[i], Tensor([102])])
        # adding attention masks corresponding to special tokens
        mask_chunks[i] = torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])


def add_padding_tokens(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """Adds padding tokens (token id = 0) at the end to make sure that all chunks have exactly 512 tokens."""
    for i in range(len(input_id_chunks)):
        # get required padding length
        pad_len = 512 - input_id_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len > 0:
            # if padding length is more than 0, we must add padding
            input_id_chunks[i] = torch.cat([input_id_chunks[i], Tensor([0] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], Tensor([0] * pad_len)])


def stack_tokens_from_all_chunks(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)
    return input_ids.long(), attention_mask.int()


def split_overlapping(tensor: Tensor, chunk_size: int, stride: int, minimal_chunk_length: int) -> list[Tensor]:
    """Helper function for dividing 1-dimensional tensors into overlapping chunks."""
    result = [tensor[i : i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        # ignore chunks with less than minimal_chunk_length number of tokens
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result
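
# Usage sketch (assumptions: the model/tokenizer paths above exist locally,
# and the sample text below is purely illustrative, not from the original):
if __name__ == "__main__":
    sample_text = (
        "Reinforcement learning agents can learn energy-efficient control "
        "policies for autonomous vehicles operating in urban environments."
    )
    predicted_class, percentages = predict_class(sample_text)
    label, badge, color = predicted_class
    print(f"Predicted class: {label} (badge: {badge}, color: {color})")
    # Print the three most probable classes
    for (name, _, _), pct in list(percentages.items())[:3]:
        print(f"  {name}: {pct:.2f}%")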