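"""Custom inference handler for the D-Nikud Hebrew diacritization model.

Implements the `EndpointHandler` interface used by Hugging Face Inference
Endpoints: the class is constructed once with the model repository path and
then invoked with a JSON-decoded request payload.
"""
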
import os
from typing import Any, Dict

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from src.models import DNikudModel, ModelConfig
from src.models_utils import predict
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import Nikud, NikudDataset


class EndpointHandler:
    def __init__(self, path=""):
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # `path` is the model repository root supplied by the Inference
        # Endpoints runtime; the default "" resolves the files below
        # relative to the working directory, as before.
        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
        dir_model_config = os.path.join(path, "models", "config.yml")
        self.config = ModelConfig.load_from_file(dir_model_config)
        self.model = DNikudModel(
            self.config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=self.DEVICE,
        ).to(self.DEVICE)
        # map_location keeps the checkpoint loadable on CPU-only hosts.
        state_dict_model = self.model.state_dict()
        state_dict_model.update(
            torch.load(
                os.path.join(path, "models", "Dnikud_best_model.pth"),
                map_location=self.DEVICE,
            )
        )
        self.model.load_state_dict(state_dict_model)
        self.model.eval()  # inference-only handler
        self.max_length = MAX_LENGTH_SEN

    def back_2_text(self, labels, text):
        """Interleave the predicted marks back into `text`.

        `labels[i][1]` holds the (nikud, dagesh, sin) class ids predicted
        for the i-th character of `text`.
        """
        nikud = Nikud()
        new_line = ""

        for indx_char, c in enumerate(text):
            # Combining marks follow the base letter: dagesh first, then
            # the sin/shin dot, then the vowel sign.
            new_line += (
                c
                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
                + nikud.id_2_char(labels[indx_char][1][2], "sin")
                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
            )
        return new_line

    def prepare_data(self, data, name="train"):
        """Tokenize (sentence, label) pairs and pad labels to `max_length`."""
        dataset = []
        for sentence, label in tqdm(data, desc=f"Prepare data {name}"):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]
            # A leading PAD row aligns the labels with the [CLS] token;
            # trailing PAD rows fill the sequence out to max_length.
            label = torch.tensor(
                [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                ]
                + label_lists[: (self.max_length - 1)]
                + [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                    for _ in range(self.max_length - len(label) - 1)
                ]
            )

            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label,
                )
            )

        # The attribute name mirrors NikudDataset.prepered_data, the
        # spelling the library uses.
        self.prepered_data = dataset

    def predict_single_text(self, text):
        """Diacritize `text` and return it with the predicted marks applied."""
        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
        data, orig_data = dataset.read_single_text(text)
        dataset.prepare_data(name="inference")
        mtb_prediction_dl = torch.utils.data.DataLoader(
            dataset.prepered_data, batch_size=BATCH_SIZE
        )
        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
        return dataset.back_2_text(labels=all_labels)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            text (str): Hebrew text to diacritize.

        Returns the input text with nikud, dagesh and sin marks interleaved.
        """
        # Get the input text; fall back to the raw payload when no
        # "text" key is present.
        inputs = data.pop("text", data)

        # Run prediction and reassemble the diacritized string.
        return self.predict_single_text(inputs)
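

# A minimal local smoke test (a sketch, not part of the Endpoints contract).
# It assumes `models/` and the `src` package are reachable from the working
# directory, exactly as the handler itself expects.
if __name__ == "__main__":
    handler = EndpointHandler()
    # Any undiacritized Hebrew string works here; this one is hypothetical.
    print(handler({"text": "שלום עולם"}))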