|
import tensorflow as tf |
|
|
|
|
|
def createDataset(targetCaptions, embeddings, batchSize, tokenizer, maxSeqLen=32, loopForever=True,
                  shuffleSize=None, encoderDims=(1, 768)):
    """Build a tf.data.Dataset of ((token_ids, attention_mask), text_embedding) batches.

    Each element pairs a tokenized, zero-padded caption with its precomputed
    text embedding. Captions that are missing, None, or longer than
    ``maxSeqLen`` tokens are skipped.

    Args:
        targetCaptions: Mapping from embedding id to a dict with a
            'caption_multi' entry (ids without an entry are skipped).
        embeddings: Iterable of dicts with 'id' and 'embedding' keys; must
            expose a ``shuffle()`` method (re-shuffled before every pass).
        batchSize: Number of examples per batch.
        tokenizer: Tokenizer exposing ``encode(text) -> list[int]``.
        maxSeqLen: Maximum token-id sequence length; shorter sequences are
            right-padded with 0.
        loopForever: If True, cycle over ``embeddings`` indefinitely
            (training); if False, stop after one pass (validation).
        shuffleSize: Optional buffer size for ``Dataset.shuffle``; no
            shuffling when None.
        encoderDims: Shape of one raw embedding as yielded by the generator
            — presumably (1, hidden_size); the leading axis is dropped in
            ``_parse_function``. TODO confirm against the embedding producer.

    Returns:
        A batched tf.data.Dataset yielding ((token_ids, attention_mask), embedding).
    """

    def generatorFunc():
        # Yields (padded token ids, attention mask, raw embedding) triples.
        while True:
            embeddings.shuffle()
            for d in embeddings:
                key, textEmb = d['id'], d['embedding']
                try:
                    caption = targetCaptions[key]['caption_multi']
                except KeyError:
                    # No caption recorded for this embedding id — skip it.
                    # (Was a bare `except: pass`, which also hid tokenizer
                    # errors and genuine bugs.)
                    continue
                if caption is None:
                    continue

                textIds = tokenizer.encode(caption)
                seqLen = len(textIds)
                if seqLen > maxSeqLen:
                    # Too long to fit the fixed sequence length; drop it.
                    continue

                padSize = maxSeqLen - seqLen
                yield textIds + [0] * padSize, [1] * seqLen + [0] * padSize, textEmb

            if not loopForever:
                break

    def _parse_function(textIds, attMask, textEmb):
        # Convert to tensors; textEmb[0] drops the leading singleton axis
        # implied by encoderDims so batching yields (batch, hidden_size).
        textIDs = tf.convert_to_tensor(textIds, tf.int32)
        att = tf.convert_to_tensor(attMask, tf.float32)
        tEmb = tf.convert_to_tensor(textEmb, tf.float32)
        return (textIDs, att), tEmb[0]

    dataset = tf.data.Dataset.from_generator(
        generatorFunc,
        output_types=(tf.int32, tf.float32, tf.float32),
        output_shapes=((maxSeqLen,), (maxSeqLen,), encoderDims),
    )

    if shuffleSize is not None:
        dataset = dataset.shuffle(shuffleSize)
    return dataset.map(_parse_function).batch(batchSize)
|
|
|
|
|
def createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize, tokenizer, targetCaptions,
                                       maxSeqLen=32, encoderDims=(1, 768)):
    """Build the training and validation datasets with shared settings.

    The training dataset cycles over its embeddings forever; the validation
    dataset makes a single pass and then stops.

    Returns:
        (trainDataset, valDataset) tuple of tf.data.Dataset objects.
    """
    shared = dict(batchSize=batchSize, tokenizer=tokenizer,
                  maxSeqLen=maxSeqLen, encoderDims=encoderDims)
    trainDataset = createDataset(targetCaptions, trainEmbeddings, loopForever=True, **shared)
    valDataset = createDataset(targetCaptions, valEmbeddings, loopForever=False, **shared)
    return trainDataset, valDataset
|
|