import tensorflow as tf

def createDataset(targetCaptions, embeddings, batchSize, tokenizer, maxSeqLen=32, loopForever=True,
                  shuffleSize=None, encoderDims=(1, 768)):
    def generatorFunc():
        # Stream (token ids, attention mask, target embedding) triples,
        # reshuffling the embeddings at the start of every pass.
        while True:
            embeddings.shuffle()
            for d in embeddings:
                key, textEmb = d['id'], d['embedding']
                try:
                    caption = targetCaptions[key]['caption_multi']
                except KeyError:
                    # Skip embeddings whose id has no caption entry.
                    continue
                if caption is None:
                    continue
                textIds = tokenizer.encode(caption)
                seqLen = len(textIds)
                # Drop captions that do not fit in the fixed sequence length.
                if seqLen > maxSeqLen:
                    continue
                # Right-pad with token id 0 and mask out the padding positions.
                padSize = maxSeqLen - seqLen
                textIds = textIds + [0] * padSize
                attMask = [1] * seqLen + [0] * padSize
                yield textIds, attMask, textEmb
            if not loopForever:
                break

    # Helper: convert Python lists / arrays to tensors of a given dtype.
    f = lambda x, y=tf.float32: tf.convert_to_tensor(x, y)

    def _parse_function(textIds, attMask, textEmb):
        textIds, attMask = f(textIds, tf.int32), f(attMask)
        textEmb = f(textEmb)
        # Squeeze the target embedding from encoderDims, e.g. (1, 768) -> (768,).
        return (textIds, attMask), textEmb[0]

    dataset = tf.data.Dataset.from_generator(
        generatorFunc,
        output_types=(tf.int32, tf.float32, tf.float32),
        output_shapes=((maxSeqLen,), (maxSeqLen,), encoderDims),
    )
    if shuffleSize is not None:
        dataset = dataset.shuffle(shuffleSize)
    dataset = dataset.map(_parse_function).batch(batchSize)
    return dataset

def createTrainingAndValidationDataset(trainEmbeddings, valEmbeddings, batchSize, tokenizer, targetCaptions,
                                       maxSeqLen=32, encoderDims=(1, 768)):
    # Validation data is finite (loopForever=False), so one evaluation pass
    # terminates; training data repeats forever and the training loop must
    # bound each epoch itself (e.g. via steps_per_epoch).
    valDataset = createDataset(targetCaptions, valEmbeddings, batchSize, tokenizer,
                               loopForever=False, maxSeqLen=maxSeqLen, encoderDims=encoderDims)
    trainDataset = createDataset(targetCaptions, trainEmbeddings, batchSize, tokenizer,
                                 loopForever=True, maxSeqLen=maxSeqLen, encoderDims=encoderDims)
    return trainDataset, valDataset
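

# A minimal usage sketch, not part of the original module. It assumes a
# HuggingFace tokenizer plus illustrative objects `trainEmbeddings` /
# `valEmbeddings` (supporting .shuffle() and yielding dicts with 'id' and
# 'embedding' keys) and `targetCaptions` (mapping id -> {'caption_multi': str});
# none of these are defined here.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#   trainDataset, valDataset = createTrainingAndValidationDataset(
#       trainEmbeddings, valEmbeddings, batchSize=64, tokenizer=tokenizer,
#       targetCaptions=targetCaptions, maxSeqLen=32, encoderDims=(1, 768))
#   # The training dataset is infinite, so bound each epoch explicitly:
#   model.fit(trainDataset, validation_data=valDataset,
#             epochs=10, steps_per_epoch=1000)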