File size: 2,354 Bytes
3bd5293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from multilingual_clip import Config_MCLIP
import tensorflow as tf
import transformers


class SentenceModel(tf.keras.Model):
    """Transformer encoder that returns attention-masked mean-pooled embeddings.

    Args:
        modelBase: Identifier or path forwarded to
            ``transformers.TFAutoModel.from_pretrained``.
        from_pt: Forwarded as ``from_pt`` to ``from_pretrained``.
        *args: Forwarded to ``tf.keras.Model.__init__``.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(self, modelBase, from_pt=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transformer = transformers.TFAutoModel.from_pretrained(modelBase, from_pt=from_pt)

    @tf.function
    def generateMeanPooledSentenceEmbs(self, input, training=False):
        """Average the transformer's last hidden states over unmasked positions.

        Args:
            input: Mapping passed unchanged to the transformer; it must contain
                an ``'attention_mask'`` entry. Presumably the mask is
                (batch, seq) and the hidden states (batch, seq, dim) -- TODO
                confirm against callers.
            training: Forwarded to the transformer as ``training``.

        Returns:
            The sum of masked hidden states along axis 1 divided by the number
            of attended positions. Rows whose mask is entirely zero yield
            zeros rather than NaN.
        """
        output = self.transformer(input, training=training)
        hiddenStates = output['last_hidden_state']

        outAtt = tf.cast(input['attention_mask'], tf.float32)
        sampleLength = tf.reduce_sum(outAtt, axis=-1, keepdims=True)
        # Zero out hidden states at masked positions before summing.
        maskedEmbs = hiddenStates * tf.expand_dims(outAtt, axis=-1)
        # divide_no_nan keeps an all-zero mask from producing 0/0 = NaN;
        # for any non-empty mask the result equals a plain division.
        return tf.math.divide_no_nan(tf.reduce_sum(maskedEmbs, axis=1), sampleLength)

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Return the mean-pooled sentence embeddings for ``inputs``.

        ``mask`` is accepted for signature compatibility and is not used.
        """
        return self.generateMeanPooledSentenceEmbs(inputs, training)


class SentenceModelWithLinearTransformation(SentenceModel):
    """``SentenceModel`` followed by a dense layer with linear activation.

    Args:
        modelBase: Forwarded to ``SentenceModel.__init__``.
        embeddingSize: Number of units of the dense output layer.
        *args: Forwarded to ``SentenceModel.__init__``.
        **kwargs: Forwarded to ``SentenceModel.__init__``.
    """

    def __init__(self, modelBase, embeddingSize=640, *args, **kwargs):
        super().__init__(modelBase, *args, **kwargs)
        # Dense layer applied to the pooled sentence embedding in ``call``.
        self.postTransformation = tf.keras.layers.Dense(
            embeddingSize,
            activation='linear',
            name='LinearTransformation',
        )

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Pool the sentence embeddings and apply the dense layer.

        ``mask`` is accepted for signature compatibility and is not used.
        """
        pooledEmbs = self.generateMeanPooledSentenceEmbs(inputs, training)
        return self.postTransformation(pooledEmbs)


class MultiLingualCLIP(transformers.TFPreTrainedModel):
    """``TFPreTrainedModel`` wrapper around ``SentenceModelWithLinearTransformation``.

    The wrapped sentence model is built from ``config.modelBase`` and
    ``config.numDims``.
    """
    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.sentenceModel = SentenceModelWithLinearTransformation(config.modelBase, config.numDims)

    @property
    def dummy_inputs(self):
        """Return all-ones int32 ``input_ids`` and ``attention_mask`` of shape (4, 12)."""
        return {'input_ids': tf.ones((4, 12), tf.int32),
                'attention_mask': tf.ones((4, 12), tf.int32)}

    @tf.function(
        input_signature=[
            tf.TensorSpec((None, None), tf.int32), tf.TensorSpec((None, None), tf.int32)
        ]
    )
    def serving(self, ids, att):
        """Serving entry point taking token ids and attention mask as two int32 tensors.

        Args:
            ids: int32 tensor of rank 2, used as ``input_ids``.
            att: int32 tensor of rank 2, used as ``attention_mask``.

        Returns:
            The result of ``call`` passed through ``serving_output``.
        """
        # The sentence model reads inputs['attention_mask'], so the tensors
        # must be packed into a dict; a tuple would fail on that string lookup.
        output = self.call({'input_ids': ids, 'attention_mask': att})
        return self.serving_output(output)

    def serving_output(self, outputs):
        """Return ``outputs`` unchanged."""
        return outputs

    @tf.function
    def call(self, inputs, training=False, mask=None):
        """Return the wrapped sentence model's output for ``inputs``.

        ``mask`` is accepted for signature compatibility and is not used.
        """
        # NOTE(review): this invokes ``.call`` directly instead of going
        # through the Keras ``__call__`` path -- presumably intentional;
        # confirm before changing.
        return self.sentenceModel.call(inputs, training)