import tensorflow as tf
import transformers


class SentenceModel(tf.keras.Model):
    """Wraps a Hugging Face transformer and mean-pools its token embeddings into a sentence embedding."""

    def __init__(self, modelBase, from_pt=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # from_pt=True loads PyTorch weights into the TF model; pass from_pt=False for TF checkpoints.
        self.transformer = transformers.TFAutoModel.from_pretrained(modelBase, from_pt=from_pt)

    @tf.function
    def generateSingleEmbedding(self, input, training=False):
        inds, att = input
        # Token-level hidden states from the transformer's last layer.
        embs = self.transformer({'input_ids': inds, 'attention_mask': att}, training=training)[0]
        outAtt = tf.cast(att, tf.float32)
        # Number of real (non-padding) tokens per sample, kept as (batch, 1) for broadcasting.
        sampleLength = tf.reduce_sum(outAtt, axis=-1, keepdims=True)
        # Zero out padding positions before averaging over the sequence dimension.
        maskedEmbs = embs * tf.expand_dims(outAtt, axis=-1)
        return tf.reduce_sum(maskedEmbs, axis=1) / tf.cast(sampleLength, tf.float32)

    @tf.function
    def generateMultipleEmbeddings(self, input, training=False):
        inds, att = input
        embs = self.transformer({'input_ids': inds, 'attention_mask': att}, training=training)['last_hidden_state']
        # Debug shape print; inside a tf.function this only runs at trace time.
        print("Embs:", embs.shape)
        outAtt = tf.cast(att, tf.float32)
        sampleLength = tf.reduce_sum(outAtt, axis=-1, keepdims=True)
        print("Att mask:", tf.expand_dims(outAtt, axis=-1).shape)
        maskedEmbs = embs * tf.expand_dims(outAtt, axis=-1)
        return tf.reduce_sum(maskedEmbs, axis=1) / tf.cast(sampleLength, tf.float32)

    @tf.function
    def call(self, inputs, training=False, mask=None):
        return self.generateSingleEmbedding(inputs, training)

    def save_pretrained(self, saveName):
        self.transformer.save_pretrained(saveName)

    def from_pretrained(self, saveName):
        # Reloads the underlying transformer weights in place.
        self.transformer = transformers.TFAutoModel.from_pretrained(saveName)
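
# A minimal usage sketch for SentenceModel, kept as comments so nothing runs at import
# time. The checkpoint name "bert-base-uncased" and max_length=128 are illustrative
# assumptions, not values taken from this file:
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = SentenceModel("bert-base-uncased")  # from_pt=True needs the PyTorch weights and torch installed
#   batch = tokenizer(["first sentence", "second sentence"],
#                     padding=True, truncation=True, max_length=128, return_tensors="tf")
#   embeddings = model((batch["input_ids"], batch["attention_mask"]))
#   # embeddings: shape (2, hidden_size), one mean-pooled vector per input sentence.
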
class SentenceModelWithLinearTransformation(SentenceModel):
    """Mean-pooled embedding followed by a single linear projection to embeddingSize dimensions."""

    def __init__(self, modelBase, embeddingSize=640, *args, **kwargs):
        super().__init__(modelBase, *args, **kwargs)
        self.postTransformation = tf.keras.layers.Dense(embeddingSize, activation='linear')

    @tf.function
    def call(self, inputs, training=False, mask=None):
        return self.postTransformation(self.generateMultipleEmbeddings(inputs, training))
class SentenceModelWithTanHTransformation(SentenceModel):
    """Mean-pooled embedding passed through a tanh layer and then a linear projection."""

    def __init__(self, modelBase, embeddingSize=640, *args, **kwargs):
        super().__init__(modelBase, *args, **kwargs)
        self.postTransformation = tf.keras.layers.Dense(embeddingSize, activation='tanh')
        self.postTransformation2 = tf.keras.layers.Dense(embeddingSize, activation='linear')

    @tf.function
    def call(self, inputs, training=False, mask=None):
        meanEmbedding = self.generateSingleEmbedding(inputs, training)
        d1 = self.postTransformation(meanEmbedding)
        return self.postTransformation2(d1)
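
# A minimal sketch showing the two projection-head variants side by side. The checkpoint
# name and the 384-dimensional target size are assumptions chosen only to illustrate shapes.
if __name__ == "__main__":
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["an example sentence"], padding=True, return_tensors="tf")
    inputs = (batch["input_ids"], batch["attention_mask"])

    # Mean pooling followed by one linear projection.
    linearModel = SentenceModelWithLinearTransformation("bert-base-uncased", embeddingSize=384)
    print(linearModel(inputs).shape)  # expected: (1, 384)

    # Mean pooling followed by a tanh layer and a linear projection.
    tanhModel = SentenceModelWithTanHTransformation("bert-base-uncased", embeddingSize=384)
    print(tanhModel(inputs).shape)  # expected: (1, 384)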