D_Nikud / src /models.py
NadavShaked's picture
Upload 7 files
91da6cc verified
raw
history blame
No virus
2.73 kB
# general
import subprocess
import yaml
# ML
import torch.nn as nn
from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig
class DNikudModel(nn.Module):
    """Frozen RoBERTa encoder followed by a BiLSTM stack and three linear heads.

    The encoder's last hidden states pass through two bidirectional LSTMs and a
    dense projection. The result is fed to three independent linear heads that
    produce per-position scores for the nikud, dagesh and sin label sets.
    """

    def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
        """Build the model.

        Args:
            config: Transformer config. ``config.hidden_size`` sizes every layer
                after the encoder. The config is also used to build a randomly
                initialised encoder when ``pretrain_model`` is None.
            nikud_size: Number of outputs of the nikud head.
            dagesh_size: Number of outputs of the dagesh head.
            sin_size: Number of outputs of the sin head.
            pretrain_model: Optional name/path passed to
                ``RobertaForMaskedLM.from_pretrained``. When None, the encoder
                is built from ``config`` instead.
            device: Device the encoder is moved to. Only the encoder is moved
                here; the LSTM/linear layers are created on the default device,
                so callers presumably call ``.to(device)`` on the whole model.
        """
        super().__init__()
        if pretrain_model is not None:
            model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
        else:
            model_base = RobertaForMaskedLM(config=config).to(device)
        # Keep only the RoBERTa encoder; the masked-LM head is not used.
        self.model = model_base.roberta
        # Freeze the encoder: its parameters receive no gradients.
        for param in self.model.parameters():
            param.requires_grad = False

        hidden = config.hidden_size
        # NOTE(review): nn.LSTM applies `dropout` only between stacked layers,
        # so with the default num_layers=1 it has no effect (PyTorch emits a
        # warning about this). Kept unchanged for compatibility.
        self.lstm1 = nn.LSTM(hidden, hidden, bidirectional=True, dropout=0.1, batch_first=True)
        # A bidirectional LSTM concatenates both directions -> 2 * hidden features.
        self.lstm2 = nn.LSTM(2 * hidden, hidden, bidirectional=True, dropout=0.1, batch_first=True)
        self.dense = nn.Linear(2 * hidden, hidden)
        self.out_n = nn.Linear(hidden, nikud_size)
        self.out_d = nn.Linear(hidden, dagesh_size)
        self.out_s = nn.Linear(hidden, sin_size)

    def forward(self, input_ids, attention_mask):
        """Compute per-position scores for the three label sets.

        Args:
            input_ids: Token ids for the encoder (assumed (batch, seq_len) — TODO confirm).
            attention_mask: Encoder attention mask (presumably same shape as ``input_ids``).

        Returns:
            Tuple ``(nikud, dagesh, sin)`` of unnormalised scores from the three
            linear heads (presumably (batch, seq_len, n_outputs) each).
        """
        last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        lstm1, _ = self.lstm1(last_hidden_state)
        lstm2, _ = self.lstm2(lstm1)
        dense = self.dense(lstm2)
        return self.out_n(dense), self.out_d(dense), self.out_s(dense)
def get_git_commit_hash(cwd=None):
    """Return the Git ``HEAD`` commit hash, or None if it cannot be determined.

    Best-effort: failures print a short message and return None instead of
    raising.

    Args:
        cwd: Directory in which to run ``git``. Defaults to None, meaning the
            process's current working directory (the original behavior).

    Returns:
        The stripped output of ``git rev-parse HEAD``, or None when git fails
        (e.g. not inside a repository) or cannot be executed at all.
    """
    try:
        output = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'],
            cwd=cwd,
            # Silence git's own "fatal: ..." message; the failure is reported below.
            stderr=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError:
        # git ran but failed, e.g. not a Git repository (or no commit yet).
        print("Not inside a Git repository!")
        return None
    except OSError:
        # git executable missing or not runnable; previously this escaped and crashed.
        print("Git is not available!")
        return None
    return output.decode('ascii').strip()
class ModelConfig(PretrainedConfig):
    """Model configuration that can be persisted to and restored from YAML.

    A fresh instance copies every attribute of the ``tau/tavbert-he``
    pretrained config, then adds ``max_length`` and the current Git commit hash.
    Passing ``dict`` instead restores a previously saved attribute mapping verbatim.
    """

    def __init__(self, max_length=None, dict=None):
        """Create the config.

        Args:
            max_length: Stored as ``self.max_length`` when building a fresh
                config; ignored when ``dict`` is given.
            dict: Attribute mapping to restore (as read by ``load_from_file``).
                The name shadows the builtin but is kept because callers pass
                it by keyword.
        """
        super().__init__()
        if dict is not None:
            self.__dict__.update(dict)
            return
        # Start from every attribute of the pretrained encoder's config.
        self.__dict__.update(AutoConfig.from_pretrained("tau/tavbert-he").__dict__)
        self.max_length = max_length
        # Record the code revision that produced this config (None outside Git).
        self._commit_hash = get_git_commit_hash()

    def print(self):
        """Print all instance attributes (``self.__dict__``) to stdout."""
        # Inside the method body, the bare name `print` is the builtin, not this method.
        print(self.__dict__)

    def save_to_file(self, file_path):
        """Write all instance attributes to ``file_path`` as block-style YAML.

        NOTE(review): this uses ``yaml.dump`` while ``load_from_file`` uses
        ``yaml.safe_load``. Any attribute that is not a plain YAML type (e.g. a
        tuple or a custom object) is written with a Python-specific tag that
        ``safe_load`` refuses to read back — confirm all values are plain.
        """
        with open(file_path, "w", encoding="utf-8") as yaml_file:
            yaml.dump(self.__dict__, yaml_file, default_flow_style=False)

    @classmethod
    def load_from_file(cls, file_path):
        """Build a config from a YAML file written by ``save_to_file``.

        Raises:
            ValueError: If the file does not contain a YAML mapping (e.g. it is
                empty). Without this check, an empty file would silently build
                a brand-new default config instead of failing.
        """
        with open(file_path, "r", encoding="utf-8") as yaml_file:
            config_dict = yaml.safe_load(yaml_file)
        if not isinstance(config_dict, type({})):
            raise ValueError(f"{file_path!r} does not contain a YAML mapping")
        return cls(dict=config_dict)