import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers
import transformers
import os

MAX_LENGTH = 512  # maximum number of tokens per tokenized input
BATCH_SIZE = 8  # number of messages processed at a time


class MeanPool(tf.keras.layers.Layer):
    """Mean-pools token embeddings, ignoring positions that are masked out."""

    def call(self, inputs, mask=None):
        # Expand the (batch, seq_len) mask to (batch, seq_len, 1) so it
        # broadcasts across the hidden dimension.
        broadcast_mask = tf.expand_dims(tf.cast(mask, "float32"), -1)
        embedding_sum = tf.reduce_sum(inputs * broadcast_mask, axis=1)
        mask_sum = tf.reduce_sum(broadcast_mask, axis=1)
        # Avoid division by zero when a row is fully masked.
        mask_sum = tf.math.maximum(mask_sum, tf.constant([1e-9]))
        return embedding_sum / mask_sum
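
# Usage sketch (hypothetical tensors): given token embeddings of shape
# (batch, seq_len, hidden) and a 0/1 attention mask of shape (batch, seq_len),
# MeanPool averages only the unmasked positions:
#     pooled = MeanPool()(token_embeddings, mask=attention_mask)  # (batch, hidden)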


class WeightsSumOne(tf.keras.constraints.Constraint):
    """Constrains a weight vector to be positive and sum to one via softmax."""

    def __call__(self, w):
        return tf.nn.softmax(w, axis=0)
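
# Usage sketch: combined with a Dense(1) layer (as in get_model below), this
# constraint makes the layer learn a convex combination of its inputs, since
# the softmaxed kernel entries are positive and sum to one:
#     layers.Dense(1, use_bias=False, kernel_constraint=WeightsSumOne())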


def deberta_init(
    pretrained_model_name: str = "microsoft/deberta-v3-large", tokenizer_dir: str = "."
):
    """Helper function to quickly initialize the config and tokenizer for a model

    Args:
        pretrained_model_name (str, optional): The model name. Defaults to "microsoft/deberta-v3-large".
        tokenizer_dir (str, optional): Directory of the tokenizer. Defaults to ".".

    Returns:
        The configuration and tokenizer of the model.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name)
    tokenizer_path = os.path.join(tokenizer_dir, "tokenizer")
    tokenizer.save_pretrained(tokenizer_path)
    cfg = transformers.AutoConfig.from_pretrained(
        pretrained_model_name, output_hidden_states=True
    )
    # Disable dropout in the backbone.
    cfg.hidden_dropout_prob = 0
    cfg.attention_probs_dropout_prob = 0
    cfg.save_pretrained(tokenizer_path)
    return cfg, tokenizer
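
# Usage sketch (downloads the tokenizer and config from the Hugging Face Hub on
# first call, so it assumes network access):
#     cfg, tokenizer = deberta_init("microsoft/deberta-v3-large", tokenizer_dir=".")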


def get_model(cfg):
    """Get a DeBERTa model using the specified configuration

    Args:
        cfg : the configuration of the model (can be generated using deberta_init)

    Returns:
        The model with respect to the given configuration.
    """
    input_ids = tf.keras.layers.Input(
        shape=(MAX_LENGTH,), dtype=tf.int32, name="input_ids"
    )
    attention_masks = tf.keras.layers.Input(
        shape=(MAX_LENGTH,), dtype=tf.int32, name="attention_masks"
    )
    # NOTE: the backbone name is hardcoded here; it should match the name
    # passed to deberta_init so the config and weights agree.
    deberta_model = transformers.TFAutoModel.from_pretrained(
        "microsoft/deberta-v3-large", config=cfg
    )

    # Re-initialize the top REINIT_LAYERS encoder blocks before fine-tuning:
    # Dense kernels get fresh Glorot-uniform weights with zero biases, and
    # LayerNormalization is reset to beta=0, gamma=1.
    REINIT_LAYERS = 1
    normal_initializer = tf.keras.initializers.GlorotUniform()
    zeros_initializer = tf.keras.initializers.Zeros()
    ones_initializer = tf.keras.initializers.Ones()

    for encoder_block in deberta_model.deberta.encoder.layer[-REINIT_LAYERS:]:
        for layer in encoder_block.submodules:
            if isinstance(layer, tf.keras.layers.Dense):
                layer.kernel.assign(
                    normal_initializer(
                        shape=layer.kernel.shape, dtype=layer.kernel.dtype
                    )
                )
                if layer.bias is not None:
                    layer.bias.assign(
                        zeros_initializer(
                            shape=layer.bias.shape, dtype=layer.bias.dtype
                        )
                    )
            elif isinstance(layer, tf.keras.layers.LayerNormalization):
                layer.beta.assign(
                    zeros_initializer(shape=layer.beta.shape, dtype=layer.beta.dtype)
                )
                layer.gamma.assign(
                    ones_initializer(shape=layer.gamma.shape, dtype=layer.gamma.dtype)
                )

    deberta_output = deberta_model.deberta(input_ids, attention_mask=attention_masks)
    hidden_states = deberta_output.hidden_states

    # WeightedLayerPool + MeanPool of the last 4 hidden states
    stack_meanpool = tf.stack(
        [MeanPool()(hidden_s, mask=attention_masks) for hidden_s in hidden_states[-4:]],
        axis=2,
    )
    weighted_layer_pool = layers.Dense(
        1, use_bias=False, kernel_constraint=WeightsSumOne()
    )(stack_meanpool)
    weighted_layer_pool = tf.squeeze(weighted_layer_pool, axis=-1)
    # One output score per class; see DEBERTA_LABEL_MAP below for the 15 labels.
    output = layers.Dense(15, activation="linear")(weighted_layer_pool)
    model = tf.keras.Model(inputs=[input_ids, attention_masks], outputs=output)

    # Compile model with Layer-wise Learning Rate Decay: the list is reversed
    # so the top encoder blocks get the largest learning rates, and earlier
    # layers (down to the embeddings) get geometrically smaller ones.
    layer_list = [deberta_model.deberta.embeddings] + list(
        deberta_model.deberta.encoder.layer
    )
    layer_list.reverse()

    INIT_LR = 1e-5
    LLRDR = 0.9  # layer-wise learning rate decay ratio
    LR_SCH_DECAY_STEPS = 1600

    lr_schedules = [
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=INIT_LR * LLRDR**i,
            decay_steps=LR_SCH_DECAY_STEPS,
            decay_rate=0.3,
        )
        for i in range(len(layer_list))
    ]
    lr_schedule_head = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4, decay_steps=LR_SCH_DECAY_STEPS, decay_rate=0.3
    )

    optimizers = [
        tf.keras.optimizers.Adam(learning_rate=lr_sch) for lr_sch in lr_schedules
    ]
    optimizers_and_layers = [
        (tf.keras.optimizers.Adam(learning_rate=lr_schedule_head), model.layers[-4:])
    ] + list(zip(optimizers, layer_list))
    optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)

    model.compile(
        optimizer=optimizer,
        loss="mse",
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )
    return model
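
# Usage sketch (assumes the pretrained backbone can be downloaded; the model is
# large, so building it needs several GB of memory):
#     cfg, _ = deberta_init()
#     model = get_model(cfg)
#     model.summary()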


def deberta_encode(texts, tokenizer):
    """Helper function to tokenize the text using the specified tokenizer

    Args:
        texts (list[str]): the messages to tokenize
        tokenizer: a tokenizer (can be generated by deberta_init)

    Returns:
        A tuple of int32 arrays (input_ids, attention_mask), each of shape
        (len(texts), MAX_LENGTH).
    """
    input_ids = []
    attention_mask = []
    for text in texts:
        token = tokenizer(
            text,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            return_attention_mask=True,
            return_tensors="np",
            truncation=True,
            padding="max_length",
        )
        input_ids.append(token["input_ids"][0])
        attention_mask.append(token["attention_mask"][0])
    return np.array(input_ids, dtype="int32"), np.array(attention_mask, dtype="int32")
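
# Usage sketch (hypothetical message):
#     input_ids, attention_mask = deberta_encode(["hello there"], tokenizer)
#     input_ids.shape  # (1, 512): one row of token ids, padded to MAX_LENGTH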


def predict(model, tokenizer, texts):
    """Predict the label for each message in texts

    Args:
        model: your DeBERTa model
        tokenizer: a tokenizer (can be generated by deberta_init)
        texts (list[str]): the messages to classify

    Returns:
        np.ndarray: the predicted integer label for each message; decode with
        decode_deberta_label.
    """
    prediction = model.predict(deberta_encode(texts, tokenizer))
    labels = np.argmax(prediction, axis=1)
    return labels
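
# Usage sketch (hypothetical messages; decode_deberta_label is defined below):
#     labels = predict(model, tokenizer, ["Hi!", "Why would you do that?"])
#     [decode_deberta_label(label) for label in labels]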


def load_model(cfg, model_dir: str = "."):
    """Helper function to load a DeBERTa model with pretrained weights

    Args:
        cfg: configuration for the model (can be generated with deberta_init)
        model_dir (str, optional): the directory of the pretrained weights. Defaults to ".".

    Returns:
        A DeBERTa model with pretrained weights.
    """
    tf.keras.backend.clear_session()
    model = get_model(cfg)
    model_path = os.path.join(model_dir, "best_model_fold2.h5")
    model.load_weights(model_path)
    return model
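
# Usage sketch (assumes best_model_fold2.h5 is present in model_dir):
#     cfg, tokenizer = deberta_init()
#     model = load_model(cfg, model_dir=".")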


# map the integer labels to their original string representation
DEBERTA_LABEL_MAP = {
    0: "Greeting",
    1: "Curiosity",
    2: "Interest",
    3: "Obscene",
    4: "Annoyed",
    5: "Openness",
    6: "Anxious",
    7: "Acceptance",
    8: "Uninterested",
    9: "Informative",
    10: "Accusatory",
    11: "Denial",
    12: "Confused",
    13: "Disapproval",
    14: "Remorse",
}


def decode_deberta_label(numeric_label):
    """Map an integer label to its string name ("Unknown Label" if absent)."""
    return DEBERTA_LABEL_MAP.get(numeric_label, "Unknown Label")
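

# Minimal end-to-end sketch, guarded so it only runs when this file is executed
# directly. It assumes the pretrained weights file and network access are
# available; the messages below are hypothetical.
if __name__ == "__main__":
    cfg, tokenizer = deberta_init()
    model = load_model(cfg)
    messages = ["Hello!", "I don't believe a word of that."]
    for message, label in zip(messages, predict(model, tokenizer, messages)):
        print(f"{message!r} -> {decode_deberta_label(label)}")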