"""Custom inference handler for the D-Nikud Hebrew diacritization model.

Loads the trained DNikudModel and serves it through the Hugging Face
Inference Endpoints handler interface: the endpoint receives raw Hebrew
text and returns it with nikud, dagesh and sin/shin diacritics added.
"""

import os
from typing import Any, Dict

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from src.models import DNikudModel, ModelConfig
from src.models_utils import predict
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import Nikud, NikudDataset


class EndpointHandler:
    def __init__(self, path=""):
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Character-level Hebrew tokenizer that D-Nikud is built on.
        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")

        # Build the model from its saved configuration, with one
        # classification head per diacritic type (nikud, dagesh, sin).
        # `path` is the deployment directory; the default "" resolves to
        # the current working directory, preserving the original behavior.
        dir_model_config = os.path.join(path, "models", "config.yml")
        self.config = ModelConfig.load_from_file(dir_model_config)
        self.model = DNikudModel(
            self.config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=self.DEVICE,
        ).to(self.DEVICE)

        # Overlay the trained weights on the freshly initialized model;
        # map_location keeps CPU-only deployments working with checkpoints
        # that were saved on GPU.
        state_dict_model = self.model.state_dict()
        state_dict_model.update(
            torch.load(
                os.path.join(path, "models", "Dnikud_best_model.pth"),
                map_location=self.DEVICE,
            )
        )
        self.model.load_state_dict(state_dict_model)
        self.model.eval()  # inference-only handler: make sure dropout is off
        self.max_length = MAX_LENGTH_SEN

    def back_2_text(self, labels, text):
        """Reassemble text with the predicted diacritics interleaved.

        `labels` is expected to carry, for each character position, a
        (nikud, dagesh, sin) triple of class ids at index [1]. Note that
        the inference path in predict_single_text uses the dataset's own
        back_2_text rather than this method.
        """
        nikud = Nikud()
        new_line = ""

        for indx_char, c in enumerate(text):
            # Each base character is followed by its dagesh mark, sin/shin
            # dot and nikud mark (ids mapping to "no diacritic" are expected
            # to yield an empty string).
            new_line += (
                c
                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
                + nikud.id_2_char(labels[indx_char][1][2], "sin")
                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
            )
        return new_line

    def prepare_data(self, data, name="train"):
        """Tokenize (sentence, label) pairs and pad the labels to max_length.

        Kept for parity with NikudDataset.prepare_data; the inference path
        in predict_single_text uses the dataset's own implementation.
        """
        dataset = []
        for sentence, label in tqdm(data, desc=f"Prepare data {name}"):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )

            # One (nikud, dagesh, sin) class-id triple per letter.
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]

            # Pad the label matrix to max_length rows: one pad row for the
            # leading special token, then the (truncated) letter labels,
            # then trailing pad rows.
            pad_row = [
                Nikud.PAD_OR_IRRELEVANT,
                Nikud.PAD_OR_IRRELEVANT,
                Nikud.PAD_OR_IRRELEVANT,
            ]
            label = torch.tensor(
                [pad_row]
                + label_lists[: (self.max_length - 1)]
                + [pad_row for _ in range(self.max_length - len(label) - 1)]
            )

            # Each item: (input_ids [max_length], attention_mask [max_length],
            # label [max_length, 3]).
            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label,
                )
            )

        # Spelling kept as "prepered_data" for consistency with NikudDataset.
        self.prepered_data = dataset

    def predict_single_text(self, text):
        """Diacritize a single piece of raw text end to end."""
        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
        # read_single_text appears to store the split sentences on the
        # dataset itself, so the returned values are unused here.
        data, orig_data = dataset.read_single_text(text)
        dataset.prepare_data(name="inference")
        mtb_prediction_dl = torch.utils.data.DataLoader(
            dataset.prepered_data, batch_size=BATCH_SIZE
        )

        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
        text_data_with_labels = dataset.back_2_text(labels=all_labels)

        return text_data_with_labels

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            text (str): the Hebrew text to diacritize; if the key is
                missing, the raw payload itself is used as the input.

        Returns:
            str: the input text with predicted nikud, dagesh and sin/shin
            diacritics (the annotation reflects the actual return type).
        """
        inputs = data.pop("text", data)
        prediction = self.predict_single_text(inputs)
        return prediction
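

# ---------------------------------------------------------------------------
# Local smoke test: a minimal usage sketch, not part of the deployed handler.
# It assumes the checkout contains models/config.yml and
# models/Dnikud_best_model.pth and that the src package is importable; the
# sample sentence is illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    print(handler({"text": "שלום עולם"}))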