import tensorflow as tf


def createDataset(targetCaptions, embeddings, batchSize, tokenizer, maxSeqLen=32, loopForever=True,
                  shuffleSize=None, encoderDims=(1, 768)):
    def generatorFunc():
        while True:
            embeddings.shuffle()
            for d in embeddings:
                key, textEmb = d['id'], d['embedding']
                try:
                    caption = targetCaptions[key]['caption_multi']
                    if caption is None:
                        continue
                    textIds = tokenizer.encode(caption)
                    seqLen = len(textIds)
                    # Skip captions that exceed the maximum sequence length.
                    if seqLen > maxSeqLen:
                        continue
                    # Right-pad the token ids with zeros and mask out the padding.
                    padSize = maxSeqLen - seqLen
                    textIds = textIds + [0] * padSize
                    attMask = [1] * seqLen + [0] * padSize
                    yield textIds, attMask, textEmb
                except KeyError:
                    # Embedding ids with no matching caption are skipped.
                    continue
            if not loopForever:
                break
    def f(x, dtype=tf.float32):
        return tf.convert_to_tensor(x, dtype)

    def _parse_function(textIds, attMask, textEmb):
        # Repackage the flat tuple into the ((inputs), target) structure Keras
        # expects, dropping the leading singleton axis from the target embedding.
        textIDs, att = f(textIds, tf.int32), f(attMask)
        tEmb = f(textEmb)
        return (textIDs, att), tEmb[0]
    dataset = tf.data.Dataset.from_generator(
        generatorFunc,
        output_types=(tf.int32, tf.float32, tf.float32),
        output_shapes=((maxSeqLen,), (maxSeqLen,), encoderDims),
    )
    if shuffleSize is not None:
        dataset = dataset.shuffle(shuffleSize)
    dataset = dataset.map(_parse_function).batch(batchSize)
    return dataset
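

# A minimal sketch of the inputs createDataset expects. The `embeddings` object
# is assumed to be a shuffleable iterable of dicts with 'id' and 'embedding'
# keys, and `targetCaptions` a mapping from those ids to dicts carrying a
# 'caption_multi' string; both structures are inferred from the code above,
# not from any library API.
#
#     targetCaptions = {"img_001": {"caption_multi": "a dog on a beach"}}
#     ds = createDataset(targetCaptions, embeddings, batchSize=32, tokenizer=tokenizer)
#     (ids, mask), emb = next(iter(ds))
#     # ids: (32, maxSeqLen) int32, mask: (32, maxSeqLen) float32, emb: (32, 768) float32

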
def createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize, tokenizer, targetCaptions,
                                       maxSeqLen=32, encoderDims=(1, 768)):
    # Validation runs for a single pass over the data; training repeats forever,
    # so model.fit needs an explicit steps_per_epoch.
    valDataset = createDataset(targetCaptions, valEmbeddings, batchSize, tokenizer,
                               loopForever=False, maxSeqLen=maxSeqLen, encoderDims=encoderDims)
    trainDataset = createDataset(targetCaptions, trainEmbeddings, batchSize, tokenizer,
                                 loopForever=True, maxSeqLen=maxSeqLen, encoderDims=encoderDims)
    return trainDataset, valDataset
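

# A self-contained smoke test, assuming only what the functions above require.
# ShuffleList and ToyTokenizer are hypothetical stand-ins invented for this
# sketch; the real project supplies its own embedding store and tokenizer (any
# object with an encode() -> list[int] method fits).
if __name__ == "__main__":
    import random

    class ShuffleList(list):
        # Minimal wrapper satisfying the .shuffle() + iteration protocol.
        def shuffle(self):
            random.shuffle(self)

    class ToyTokenizer:
        def encode(self, text):
            # Map each whitespace token to a small dummy id.
            return [(len(w) % 50) + 1 for w in text.split()]

    captions = {i: {'caption_multi': f'caption number {i}'} for i in range(8)}
    embs = ShuffleList({'id': i, 'embedding': tf.random.normal((1, 768))} for i in range(8))
    train, val = createTrainingAndValidationDataset(embs, embs, 4, ToyTokenizer(), captions)
    (ids, mask), target = next(iter(val))
    print(ids.shape, mask.shape, target.shape)  # (4, 32) (4, 32) (4, 768)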