|
import Dataset, TrainingModel |
|
import tensorflow as tf |
|
import transformers |
|
import datasets |
|
import Utils |
|
|
|
|
|
def loadTextTranslations():
    """Fetch the train split of the M-CLIP 7M machine-translated caption dataset."""
    translations = datasets.load_dataset('M-CLIP/ImageCaptions-7M-Translations')
    return translations['train']
|
|
|
|
|
def loadTargetEmbeddings(imageBase="Vit-B-32", validationSize=5000):
    """Load precomputed image-embedding targets and carve off a validation head.

    The first `validationSize` rows of the train split become the validation
    set; everything after them is used for training.

    Args:
        imageBase: dataset configuration name selecting the image encoder
            whose embeddings to load (e.g. "Vit-B-32").
        validationSize: number of leading rows reserved for validation.

    Returns:
        Tuple of (trainSamples, valSamples, embeddingShape) where
        embeddingShape is the tf shape of a single target embedding.
    """
    repo = 'M-CLIP/ImageCaptions-7M-Embeddings'
    trainSplit = 'train[{}:]'.format(validationSize)
    valSplit = 'train[:{}]'.format(validationSize)

    trainSamples = datasets.load_dataset(repo, imageBase, split=trainSplit)
    valSamples = datasets.load_dataset(repo, imageBase, split=valSplit)

    # Probe one sample to discover the dimensionality of the regression target.
    embeddingShape = tf.convert_to_tensor(trainSamples[0]['embedding']).shape

    return trainSamples, valSamples, embeddingShape
|
|
|
|
|
def singleGPUTraining():
    """Fine-tune a text encoder to regress onto precomputed image embeddings.

    Builds an xlm-roberta-large sentence model with a linear projection to the
    image-encoder dimensionality, then fits it with MSE loss against the
    target embeddings. Hyper-parameters are hard-coded below. Checkpoints are
    written periodically via Utils.CustomSaveCallBack under the name
    '<modelBase>-<imageBase>'.
    """
    numValidationSamples = 5000
    stepsPerEpoch, lr = 1000, 0.00001
    gradAccumSteps, batchSize = 1, 256
    # numTrainSteps is effectively "unbounded"; training stops externally.
    numTrainSteps, numWarmupSteps = 99999999, 1000

    modelBase = 'xlm-roberta-large'
    tokenizerBase = 'xlm-roberta-large'
    imageBase = "Vit-B-32"
    modelName = '{}-{}'.format(modelBase, imageBase)

    # Set to a checkpoint path to resume from previous weights.
    startWeights = None

    targetCaptions = loadTextTranslations()
    trainEmbeddings, valEmbeddings, imageEncoderDimensions = loadTargetEmbeddings(validationSize=numValidationSamples)

    def createOptimizerFunc():
        # Warmup + decay schedule from transformers; wrap in a gradient
        # accumulator only when accumulation is actually enabled.
        optimizer, schedule = transformers.optimization_tf.create_optimizer(lr, numTrainSteps, numWarmupSteps)
        if gradAccumSteps > 1:
            return Utils.GradientAccumulator(optimizer, gradAccumSteps)
        return optimizer

    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizerBase)
    # Project the sentence embedding to the image encoder's output width.
    model = TrainingModel.SentenceModelWithLinearTransformation(modelBase, imageEncoderDimensions[-1])

    if startWeights is not None:
        model.load_weights(startWeights)

    model.compile(createOptimizerFunc(), 'mse', metrics=['mae', 'cosine_similarity'])

    trainDataset, valDataset = Dataset.createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize,
                                                                          tokenizer,
                                                                          targetCaptions=targetCaptions,
                                                                          encoderDims=imageEncoderDimensions)

    # With accumulation, each optimizer step consumes several batches, so
    # stretch the epoch length to keep the number of updates per epoch fixed.
    if gradAccumSteps > 1:
        stepsPerEpoch *= gradAccumSteps

    model.fit(trainDataset, epochs=1000, steps_per_epoch=stepsPerEpoch,
              validation_data=valDataset,
              callbacks=[
                  Utils.CustomSaveCallBack(modelName, saveInterval=5, firstSavePoint=5),
              ]
              )
|
|
|
|
|
if __name__ == '__main__':
    # Replicate variables/ops across all visible GPUs; with a single GPU this
    # scope is effectively a no-op, matching the function's name.
    distributionStrategy = tf.distribute.MirroredStrategy()
    with distributionStrategy.scope():
        singleGPUTraining()