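"""Teacher learning: fit a multilingual sentence encoder to precomputed CLIP embeddings.

An XLM-RoBERTa model with a linear projection head is trained with an MSE
objective so that its output for a (translated) caption matches the
precomputed CLIP target embedding of the same sample.
"""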
import datasets
import tensorflow as tf
import transformers

# Project-local modules: data pipeline, model definition, training utilities.
import Dataset
import TrainingModel
import Utils


def loadTextTranslations():
    # Translated captions that serve as the text side of the training data.
    return datasets.load_dataset('M-CLIP/ImageCaptions-7M-Translations')['train']


def loadTargetEmbeddings(imageBase="Vit-B-32", validationSize=5000):
    # Precomputed CLIP target embeddings. The first `validationSize` rows are
    # held out for validation; the remainder is used for training.
    trainSamples = datasets.load_dataset('M-CLIP/ImageCaptions-7M-Embeddings', imageBase,
                                         split='train[{}:]'.format(validationSize))
    valSamples = datasets.load_dataset('M-CLIP/ImageCaptions-7M-Embeddings', imageBase,
                                       split='train[:{}]'.format(validationSize))

    embeddingShape = tf.convert_to_tensor(trainSamples[0]['embedding']).shape
    return trainSamples, valSamples, embeddingShape
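
# A quick sanity check of the loaders (hypothetical usage, not run during
# training); the shape shown is only what one would expect from a
# 512-dimensional image encoder such as Vit-B-32:
#
#   train, val, dims = loadTargetEmbeddings(validationSize=5000)
#   print(dims)                  # e.g. TensorShape([512])
#   print(len(train), len(val))  # presumably ~7M - 5000, and 5000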


def singleGPUTraining():
    numValidationSamples = 5000
    stepsPerEpoch, lr = 1000, 1e-5
    gradAccumSteps, batchSize = 1, 256
    # A very large step count keeps the warmup + linear-decay schedule
    # effectively flat after warmup.
    numTrainSteps, numWarmupSteps = 99999999, 1000

    modelBase = 'xlm-roberta-large'
    tokenizerBase = 'xlm-roberta-large'
    imageBase = "Vit-B-32"
    modelName = '{}-{}'.format(modelBase, imageBase)

    startWeights = None  # Set to a checkpoint path to resume training.
    targetCaptions = loadTextTranslations()
    trainEmbeddings, valEmbeddings, imageEncoderDimensions = loadTargetEmbeddings(validationSize=numValidationSamples)

    def createOptimizerFunc():
        # Adam with warmup and linear decay, as created by transformers.
        optimizer, _ = transformers.optimization_tf.create_optimizer(lr, numTrainSteps, numWarmupSteps)
        if gradAccumSteps <= 1:
            return optimizer
        return Utils.GradientAccumulator(optimizer, gradAccumSteps)
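    # Utils.GradientAccumulator is project-local; the usual pattern (an
    # assumption about its internals) is to sum gradients across
    # gradAccumSteps mini-batches and apply them in a single optimizer step,
    # emulating an effective batch of batchSize * gradAccumSteps.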

    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizerBase)
    model = TrainingModel.SentenceModelWithLinearTransformation(modelBase, imageEncoderDimensions[-1])

    if startWeights is not None:
        model.load_weights(startWeights)
    # Regress onto the CLIP embeddings with MSE; also track MAE and cosine similarity.
    model.compile(createOptimizerFunc(), 'mse', metrics=['mae', 'cosine_similarity'])

    trainDataset, valDataset = Dataset.createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize,
                                                                          tokenizer,
                                                                          targetCaptions=targetCaptions,
                                                                          encoderDims=imageEncoderDimensions)

    if gradAccumSteps > 1:  # Scale the steps so that logged epochs stay comparable on Wandb.
        stepsPerEpoch *= gradAccumSteps

    model.fit(trainDataset, epochs=1000, steps_per_epoch=stepsPerEpoch,
              validation_data=valDataset,
              callbacks=[
                  Utils.CustomSaveCallBack(modelName, saveInterval=5, firstSavePoint=5),
              ])
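
# For orientation, a minimal sketch of what SentenceModelWithLinearTransformation
# plausibly computes (the real definition lives in TrainingModel; the pooling
# choice and names below are assumptions, not the repository's code):
#
#   tokenEmbeddings = transformer(inputIds, attentionMask)    # (batch, seq, hidden)
#   pooled = maskedMeanPool(tokenEmbeddings, attentionMask)   # (batch, hidden)
#   projected = tf.keras.layers.Dense(embeddingDims)(pooled)  # (batch, embeddingDims)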


if __name__ == '__main__':
    # MirroredStrategy also works with a single GPU and scales to several;
    # the model must be built inside the strategy scope.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        singleGPUTraining()
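
# To launch training (the script filename here is an assumption):
#
#   python TeacherLearning.py
#
# MirroredStrategy replicates the model across every visible GPU; pin a single
# device for a true single-GPU run, e.g.:
#
#   CUDA_VISIBLE_DEVICES=0 python TeacherLearning.py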