import tensorflow as tf def createDataset(targetCaptions, embeddings, batchSize, tokenizer, maxSeqLen=32, loopForever=True, shuffleSize=None, encoderDims=(1, 768)): def generatorFunc(): while True: embeddings.shuffle() for d in embeddings: key, textEmb = d['id'], d['embedding'] try: caption = targetCaptions[key]['caption_multi'] if (caption is None): continue textIds = tokenizer.encode(caption) seqLen = len(textIds) if (seqLen > maxSeqLen): continue padSize = maxSeqLen - len(textIds) textIds = textIds + [0] * padSize attMask = [1] * seqLen + [0] * padSize yield textIds, attMask, textEmb except: pass if (loopForever == False): break f = lambda x, y=tf.float32: tf.convert_to_tensor(x, y) def _parse_function(textIds, attMask, textEmb): textIDs, att = f(textIds, tf.int32), f(attMask) tEmb = f(textEmb) return (textIDs, att), tEmb[0] dataset = tf.data.Dataset.from_generator(generatorFunc, output_types=( tf.int32, tf.float32, tf.float32), output_shapes=( (maxSeqLen,), (maxSeqLen,), encoderDims), ) if (shuffleSize is not None): dataset = dataset.shuffle(shuffleSize) dataset = dataset.map(_parse_function).batch(batchSize) return dataset def createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize, tokenizer, targetCaptions, maxSeqLen=32, encoderDims=(1, 768)): valDataset = createDataset(targetCaptions, valEmbeddings, batchSize, tokenizer, loopForever=False, maxSeqLen=maxSeqLen, encoderDims=encoderDims) trainDataset = createDataset(targetCaptions, trainEmbeddings, batchSize, tokenizer, loopForever=True, maxSeqLen=maxSeqLen, encoderDims=encoderDims) return trainDataset, valDataset