# Author: medmac01 -- commit 3bd5293: "Added multilingual_clip module" (2.35 kB).
from multilingual_clip import Config_MCLIP
import tensorflow as tf
import transformers
class SentenceModel(tf.keras.Model):
    """Keras model that mean-pools a Hugging Face TF transformer's token states.

    The transformer is loaded with ``transformers.TFAutoModel.from_pretrained``.
    Its ``last_hidden_state`` output is averaged over the positions selected by
    the attention mask, giving one vector per input sequence.
    """

    def __init__(self, modelBase, from_pt=True, *args, **kwargs):
        """Load the underlying transformer.

        Args:
            modelBase: Model name or path passed to
                ``TFAutoModel.from_pretrained``.
            from_pt: Forwarded to ``from_pretrained``. Per the transformers
                docs, this requests loading PyTorch-format weights.
            *args, **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.transformer = transformers.TFAutoModel.from_pretrained(modelBase, from_pt=from_pt)

    @tf.function
    def generateMeanPooledSentenceEmbs(self, input, training=False):
        """Return the attention-masked mean of the transformer's last hidden state.

        Args:
            input: Mapping passed to the transformer. It must contain an
                ``'attention_mask'`` entry (presumably alongside
                ``'input_ids'`` -- confirm against callers).
            training: Forwarded to the transformer call.

        Returns:
            Pooled embeddings, one row per sequence. This assumes
            ``last_hidden_state`` is (batch, seq, hidden) -- TODO confirm.
            A row whose attention mask is entirely zero yields zeros rather
            than NaN.
        """
        # NOTE: ``input`` shadows the builtin, but the name is kept so that
        # existing keyword callers keep working.
        output = self.transformer(input, training=training)
        hiddenStates = output['last_hidden_state']
        outAtt = tf.cast(input['attention_mask'], tf.float32)
        # Count of attended positions per sequence. keepdims=True leaves a
        # trailing size-1 axis so the count broadcasts against the summed
        # embeddings below.
        sampleLength = tf.reduce_sum(outAtt, axis=-1, keepdims=True)
        # Zero out padded positions before summing over the sequence axis.
        maskedEmbs = hiddenStates * tf.expand_dims(outAtt, axis=-1)
        summedEmbs = tf.reduce_sum(maskedEmbs, axis=1)
        # divide_no_nan returns 0 where the count is 0 (a fully masked row),
        # where a plain division would give 0/0 = NaN. Everywhere else it is
        # an ordinary division.
        return tf.math.divide_no_nan(summedEmbs, sampleLength)

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Return mean-pooled sentence embeddings for ``inputs``.

        ``mask`` is accepted for Keras API compatibility but is not used.
        """
        return self.generateMeanPooledSentenceEmbs(inputs, training)
class SentenceModelWithLinearTransformation(SentenceModel):
    """Mean-pooled sentence encoder followed by a learned linear projection.

    Extends ``SentenceModel`` with a ``Dense`` layer whose activation is
    ``'linear'`` (no non-linearity). The layer maps each pooled embedding to
    ``embeddingSize`` dimensions.
    """

    def __init__(self, modelBase, embeddingSize=640, *args, **kwargs):
        """Build the base encoder and the projection layer.

        Args:
            modelBase: Forwarded to ``SentenceModel.__init__``.
            embeddingSize: Output width of the projection layer.
            *args, **kwargs: Forwarded to ``SentenceModel.__init__``. A third
                positional argument lands in ``from_pt``.
        """
        super().__init__(modelBase, *args, **kwargs)
        # NOTE(review): the attribute and layer names are kept verbatim. They
        # presumably have to match previously saved weights -- confirm before
        # renaming.
        self.postTransformation = tf.keras.layers.Dense(
            units=embeddingSize,
            activation='linear',
            name='LinearTransformation',
        )

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Pool ``inputs`` into sentence vectors, then project them.

        ``mask`` is accepted for Keras API compatibility but is not used.
        """
        pooled = self.generateMeanPooledSentenceEmbs(inputs, training)
        projected = self.postTransformation(pooled)
        return projected
class MultiLingualCLIP(transformers.TFPreTrainedModel):
    """Multilingual CLIP text encoder exposed as a Hugging Face TF pretrained model.

    Wraps a ``SentenceModelWithLinearTransformation`` built from
    ``config.modelBase`` and ``config.numDims``. The config type is
    ``Config_MCLIP.MCLIPConfig``.
    """

    config_class = Config_MCLIP.MCLIPConfig

    @property
    def dummy_inputs(self):
        """Return a small all-ones int32 batch of shape (4, 12) used to build the model."""
        return {'input_ids': tf.ones((4, 12), tf.int32),
                'attention_mask': tf.ones((4, 12), tf.int32)}

    @tf.function(
        input_signature=[
            tf.TensorSpec((None, None), tf.int32), tf.TensorSpec((None, None), tf.int32)
        ]
    )
    def serving(self, ids, att):
        """Serving entry point that takes token ids and attention mask as separate tensors.

        Args:
            ids: Rank-2 int32 token ids (presumably batch x sequence).
            att: Rank-2 int32 attention mask with the same layout as ``ids``.

        Returns:
            The output of ``call``, passed through ``serving_output``.
        """
        # The pooling code reads ``inputs['attention_mask']``, so the tensors
        # must be packed into a dict. Passing the tuple ``(ids, att)`` raised
        # TypeError when a string was used to index that tuple.
        output = self.call({'input_ids': ids, 'attention_mask': att})
        return self.serving_output(output)

    def serving_output(self, outputs):
        """Return ``outputs`` unchanged; no serving-specific post-processing is applied."""
        return outputs

    def __init__(self, config, *args, **kwargs):
        """Build the wrapped sentence encoder from ``config.modelBase`` and ``config.numDims``."""
        super().__init__(config, *args, **kwargs)
        self.sentenceModel = SentenceModelWithLinearTransformation(config.modelBase, config.numDims)

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Encode ``inputs`` into projected sentence embeddings.

        ``mask`` is accepted for Keras API compatibility but is not used.
        """
        # NOTE(review): this invokes ``.call`` directly instead of
        # ``self.sentenceModel(...)``. It is left unchanged because going
        # through Keras ``__call__`` may alter variable name scopes and
        # therefore compatibility with saved checkpoints -- verify before
        # changing.
        return self.sentenceModel.call(inputs, training)