D_Nikud / run handler.py
AbdulRahim07's picture
Rename handler.py to run handler.py
b4febfe verified
raw
history blame
No virus
4.41 kB
import logging
import os
from typing import Dict, List, Any

import torch
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

from src.models import DNikudModel, ModelConfig
from src.models_utils import predict_single, predict
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import Nikud, NikudDataset
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler wrapping a ``DNikudModel`` and its tokenizer.

    The handler loads the ``tau/tavbert-he`` tokenizer and the model weights
    once at construction time, and is then called with a request payload to
    produce the text annotated with the predicted nikud / dagesh / sin marks.
    """

    def __init__(self, path=""):
        """Load tokenizer, model configuration and model weights.

        Args:
            path: Directory that contains the ``models`` folder (holding
                ``config.yml`` and ``Dnikud_best_model.pth``). The default
                ``""`` resolves the folder relative to the current working
                directory, which is what the handler did before ``path`` was
                honoured.
        """
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")

        # Resolve model files under the directory handed to the handler
        # instead of assuming the process was started from the repo root.
        models_dir = os.path.join(path or "", "models")
        self.config = ModelConfig.load_from_file(
            os.path.join(models_dir, "config.yml")
        )
        self.model = DNikudModel(
            self.config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=self.DEVICE,
        ).to(self.DEVICE)

        # Start from the freshly initialised weights and overlay the
        # checkpoint, so keys missing from the checkpoint keep their
        # initial values. map_location lets a checkpoint saved on GPU load
        # on a CPU-only host.
        state_dict_model = self.model.state_dict()
        state_dict_model.update(
            torch.load(
                os.path.join(models_dir, "Dnikud_best_model.pth"),
                map_location=self.DEVICE,
            )
        )
        self.model.load_state_dict(state_dict_model)
        # The handler only serves predictions: disable dropout and other
        # training-only behaviour.
        self.model.eval()
        self.max_length = MAX_LENGTH_SEN

    def back_2_text(self, labels, text):
        """Interleave each character of ``text`` with its predicted marks.

        Args:
            labels: Indexable per-character predictions. ``labels[i][1]`` is
                read as ``[nikud_id, dagesh_id, sin_id]`` for character ``i``.
            text: The plain input string.

        Returns:
            A string in which every character is followed by its dagesh,
            sin and nikud marks, in that order.

        Raises:
            IndexError: If ``labels`` has fewer entries than ``text`` has
                characters.
        """
        nikud = Nikud()
        pieces = []
        for indx_char, c in enumerate(text):
            char_ids = labels[indx_char][1]
            pieces.append(
                c
                + nikud.id_2_char(char_ids[1], "dagesh")
                + nikud.id_2_char(char_ids[2], "sin")
                + nikud.id_2_char(char_ids[0], "nikud")
            )
        return "".join(pieces)

    def prepare_data(self, data, name="train"):
        """Tokenize sentences and build fixed-length label tensors.

        The result is stored on ``self.prepered_data`` as a list of
        ``(input_ids, attention_mask, label)`` tuples; nothing is returned.

        Args:
            data: Iterable of ``(sentence, label)`` pairs, where ``label`` is
                a sequence of objects exposing ``nikud``, ``dagesh`` and
                ``sin`` attributes.
            name: Tag shown in the progress-bar description.
        """
        logger.debug("Preparing data (%s)", name)
        pad_row = [Nikud.PAD_OR_IRRELEVANT] * 3
        dataset = []
        for sentence, label in tqdm(data, desc=f"Prepare data {name}"):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]
            # One leading padding row, then the (truncated) per-letter labels,
            # then trailing padding rows so that every label tensor has
            # exactly ``self.max_length`` rows.
            n_trailing = max(self.max_length - len(label) - 1, 0)
            label_tensor = torch.tensor(
                [pad_row]
                + label_lists[: self.max_length - 1]
                + [pad_row] * n_trailing
            )
            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label_tensor,
                )
            )
        self.prepered_data = dataset

    def predict_single_text(
        self,
        text,
    ):
        """Run the model on ``text`` and return the annotated result.

        Args:
            text: Raw input text handed to ``NikudDataset.read_single_text``.

        Returns:
            Whatever ``NikudDataset.back_2_text`` produces for the predicted
            labels.
        """
        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=self.max_length)
        data, _orig_data = dataset.read_single_text(text)
        logger.debug("Read %d item(s) from input text", len(data))
        dataset.prepare_data(name="inference")
        mtb_prediction_dl = torch.utils.data.DataLoader(
            dataset.prepered_data, batch_size=BATCH_SIZE
        )
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
        return dataset.back_2_text(labels=all_labels)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload. The text is taken from the ``"text"`` key
                when present, otherwise from ``"inputs"``; the key that is
                used is popped from ``data``. If neither key exists (or the
                payload is not a dict) the payload itself is used as the
                input.

        Returns:
            The result of ``predict_single_text`` for the extracted input.
            NOTE(review): the annotation says ``List[Dict[str, Any]]`` but the
            value comes straight from ``NikudDataset.back_2_text`` - confirm
            that it matches.
        """
        if not isinstance(data, dict):
            inputs = data
        elif "text" in data:
            inputs = data.pop("text")
        elif "inputs" in data:
            inputs = data.pop("inputs")
        else:
            inputs = data
        return self.predict_single_text(inputs)