from functools import lru_cache
import json
import math
import re
from typing import List, Union
from pathlib import Path
import torch
from torch import Tensor


def load_text_file(
        file_path: Union[Path, str],
        encoding='utf-8',
        *args, **kwargs
        ) -> str:
    with open(file_path, 'r', encoding=encoding) as f:
        data = f.read()
    return data


def save_text_file(
        file_path: Union[Path, str],
        data: str,
        encoding='utf-8'
        ) -> str:
    with open(file_path, 'w', encoding=encoding) as f:
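        # f.write returns the number of characters written; do not
        # overwrite `data`, so the function returns the saved text.
        f.write(data)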
    return data


def remove_long_spaces(line: str) -> str:
    # Raw string so that \s is a regex escape, not a string escape.
    return re.sub(r'\s{2,}', ' ', line)


@lru_cache(maxsize=2)
def get_positionals(max_length: int, d_model: int) -> Tensor:
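    """Build a sinusoidal positional-encoding tensor to add to the input.

    Args:
        max_length (int): The maximum sequence length (number of positions).
        d_model (int): The encoding dimensionality; must be even, since
            sine and cosine values are written in pairs.

    Returns:
        Tensor: A float tensor of shape (max_length, d_model). The result
            is cached, so callers should not modify it in place.
    """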
    result = torch.zeros(max_length, d_model, dtype=torch.float)
    for pos in range(max_length):
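        # i walks the even indices: sine goes to i, cosine to i + 1.
        for i in range(0, d_model, 2):
            # Note: i is already the even index 2k, so the exponent here is
            # 4k / d_model, not the 2k / d_model of the original Transformer
            # formulation.
            denominator = pow(10000, 2 * i / d_model)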
            result[pos, i] = math.sin(pos / denominator)
            result[pos, i + 1] = math.cos(pos / denominator)
    return result


def load_json(file_path: Union[Path, str]) -> Union[dict, list]:
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data


def save_json(
        file_path: Union[Path, str], data: Union[dict, list]
        ) -> None:
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f)


def get_freq_dict(data: List[str]) -> dict:
    freq = {}
    for item in data:
        for word in item.split(' '):
            if word in freq:
                freq[word] += 1
            else:
                freq[word] = 1
    return freq


def load_state(state_path: Union[Path, str]):
    state = torch.load(state_path)
    model = state['model']
    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
    model = {
        key.replace('module.', ''): value
        for key, value in model.items()
        }
    optimizer = state['optimizer']
    epoch = state['epoch']
    steps = state['steps']
    return model, optimizer, epoch, steps