from transformers import PretrainedConfig, PreTrainedModel
import torch.nn as nn
import torch
import torch.nn.functional as F
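
# FOREcasT: a linear model over hand-crafted indel features around a CRISPR
# cut site, wrapped in the Hugging Face PretrainedConfig/PreTrainedModel API.
# A single bias-free linear layer scores each candidate repair outcome; the
# loss is a KL divergence against observed outcome counts plus a per-feature
# L2 penalty on the weights.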

class FOREcasTConfig(PretrainedConfig):
    model_type = "FOREcasT"
    label_names = ["count"]

    def __init__(
        self,
        reg_const=0.01,  # L2 regularization coefficient for deletion features
        i1_reg_const=0.01,  # L2 regularization coefficient for insertion features
        seed=63036,  # random seed for weight initialization
        **kwargs,
    ):
        self.reg_const = reg_const
        self.i1_reg_const = i1_reg_const
        self.seed = seed
        super().__init__(**kwargs)

class FOREcasTModel(PreTrainedModel):
    config_class = FOREcasTConfig

    @staticmethod
    def get_feature_label():
        def features_pairwise_label(features1_label, features2_label):
            # Cartesian product of two label lists, preserving order.
            features_label = []
            for label1 in features1_label:
                for label2 in features2_label:
                    features_label.append(f'PW_{label1}_vs_{label2}')
            return features_label

        feature_DelSize_label = ["Any Deletion", "D1", "D2-3", "D4-7", "D8-12", "D>12"]
        feature_InsSize_label = ["Any Insertion", "I1", "I2"]
        feature_DelLoc_label = ['DL-1--1', 'DL-2--2', 'DL-3--3', 'DL-4--6', 'DL-7--10', 'DL-11--15', 'DL-16--30', 'DL<-30', 'DL>=0', 'DR0-0', 'DR1-1', 'DR2-2', 'DR3-5', 'DR6-9', 'DR10-14', 'DR15-29', 'DR<0', 'DR>=30']
        feature_InsSeq_label = ["I1_A", "I1_C", "I1_G", "I1_T", "I2_AA", "I2_AC", "I2_AG", "I2_AT", "I2_CA", "I2_CC", "I2_CG", "I2_CT", "I2_GA", "I2_GC", "I2_GG", "I2_GT", "I2_TA", "I2_TC", "I2_TG", "I2_TT"]
        feature_InsLoc_label = ["IL-1--1", "IL-2--2", "IL-3--3", "IL<-3", "IL>=0"]
        feature_LocalCutSiteSequence_label = []
        for offset in range(-5, 4):
            for nt in ["A", "G", "C", "T"]:
                feature_LocalCutSiteSequence_label.append(f"CS{offset}_NT={nt}")
        feature_LocalCutSiteSeqMatches_label = []
        for offset1 in range(-3, 2):
            for offset2 in range(-3, offset1):
                for nt in ["A", "G", "C", "T"]:
                    feature_LocalCutSiteSeqMatches_label.append(f"M_CS{offset1}_{offset2}_NT={nt}")
        feature_LocalRelativeSequence_label = []
        for offset in range(-3, 3):
            for nt in ["A", "G", "C", "T"]:
                feature_LocalRelativeSequence_label.append(f'L{offset}_NT={nt}')
        for offset in range(-3, 3):
            for nt in ["A", "G", "C", "T"]:
                feature_LocalRelativeSequence_label.append(f'R{offset}_NT={nt}')
        feature_SeqMatches_label = []
        for loffset in range(-3, 3):
            for roffset in range(-3, 3):
                feature_SeqMatches_label.append(f'X_L{loffset}_R{roffset}')
                feature_SeqMatches_label.append(f'M_L{loffset}_R{roffset}')
        feature_I1or2Rpt_label = ['I1Rpt', 'I1NonRpt', 'I2Rpt', 'I2NonRpt']
        feature_microhomology_label = ['L_MH1-1', 'R_MH1-1', 'L_MH2-2', 'R_MH2-2', 'L_MH3-3', 'R_MH3-3', 'L_MM1_MH3-3', 'R_MM1_MH3-3', 'L_MH4-6', 'R_MH4-6', 'L_MM1_MH4-6', 'R_MM1_MH4-6', 'L_MH7-10', 'R_MH7-10', 'L_MM1_MH7-10', 'R_MM1_MH7-10', 'L_MH11-15', 'R_MH11-15', 'L_MM1_MH11-15', 'R_MM1_MH11-15', 'No MH']
        return (
            features_pairwise_label(feature_DelSize_label, feature_DelLoc_label) +
            feature_InsSize_label +
            feature_DelSize_label +
            feature_DelLoc_label +
            feature_InsLoc_label +
            feature_InsSeq_label +
            features_pairwise_label(feature_LocalCutSiteSequence_label, feature_InsSize_label + feature_DelSize_label) +
            features_pairwise_label(feature_microhomology_label + feature_LocalRelativeSequence_label, feature_DelSize_label + feature_DelLoc_label) +
            features_pairwise_label(feature_LocalCutSiteSeqMatches_label + feature_SeqMatches_label, feature_DelSize_label) +
            features_pairwise_label(feature_InsSeq_label + feature_LocalCutSiteSequence_label + feature_LocalCutSiteSeqMatches_label, feature_I1or2Rpt_label) +
            feature_I1or2Rpt_label +
            feature_LocalCutSiteSequence_label +
            feature_LocalCutSiteSeqMatches_label +
            feature_LocalRelativeSequence_label +
            feature_SeqMatches_label +
            feature_microhomology_label
        )
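
    # The concatenation order above fixes the feature layout; its length sets
    # in_features of the linear layer built in __init__ below.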

    def __init__(self, config) -> None:
        super().__init__(config)
        # In recent versions of PyTorch there is no need to call
        # register_parameter explicitly: assigning an nn.Parameter to an
        # attribute of an nn.Module is enough for it to be registered as a
        # trainable parameter (https://stackoverflow.com/questions/59234238/how-to-add-parameters-in-module-class-in-pytorch-custom-model).
        self.generator = torch.Generator().manual_seed(config.seed)
        # Deletion features (labels without an 'I') are penalized with
        # reg_const, insertion features with i1_reg_const.
        is_delete = torch.tensor(['I' not in label for label in FOREcasTModel.get_feature_label()])
        self.register_buffer('reg_coff', is_delete * config.reg_const + ~is_delete * config.i1_reg_const)
        self.linear = nn.Linear(in_features=len(self.reg_coff), out_features=1, bias=False)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=1, generator=self.generator)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, feature, count=None) -> dict:
        # feature: (..., num_features); the linear layer yields one logit per
        # candidate outcome. Squeeze only the trailing singleton dimension so
        # a leading batch dimension of size 1 is preserved.
        logit = self.linear(feature).squeeze(-1)
        if count is not None:
            return {
                "logit": logit,
                "loss": self.kl_divergence(logit, count),
            }
        return {"logit": logit}

    def kl_divergence(self, logit, count):
        # KL divergence between the observed outcome frequencies and the
        # predicted distribution, plus an L2 weight penalty weighted per
        # feature by reg_coff and scaled by the number of outcomes.
        return F.kl_div(
            F.log_softmax(logit, dim=-1),
            F.normalize(count + 0.5, p=1.0, dim=-1),  # add 0.5 to prevent log(0), see loadOligoFeaturesAndReadCounts
            reduction='sum',
        ) + logit.shape[0] * (self.reg_coff * (self.linear.weight ** 2)).sum()
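

if __name__ == "__main__":
    # Minimal smoke test: a sketch added for illustration, not part of the
    # original training pipeline. The shapes are assumptions: `feature` holds
    # one feature vector per candidate repair outcome and `count` the observed
    # read count of each outcome.
    config = FOREcasTConfig()
    model = FOREcasTModel(config)
    n_features = len(FOREcasTModel.get_feature_label())
    num_outcomes = 7  # hypothetical number of candidate outcomes
    feature = torch.rand(num_outcomes, n_features)
    count = torch.randint(0, 100, (num_outcomes,)).float()
    output = model(feature=feature, count=count)
    print(output["logit"].shape)  # -> torch.Size([7])
    print(output["loss"].item())  # KL divergence plus L2 penalty (a scalar)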