|
|
|
|
|
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments.

    The sentences of ``sample`` are split at a single point: the leading
    sentences are flattened into segment A and the remaining ones into
    segment B. With probability 0.5 the two segments are then swapped.

    Args:
        sample: Indexable, sized collection of sentences, where each
            sentence is an iterable of tokens. Must hold at least two
            sentences.
        np_rng: Random generator exposing ``randint(low, high)`` and
            ``random()``. Assumed to follow NumPy's convention of an
            exclusive upper bound for ``randint`` (e.g.
            ``numpy.random.RandomState``) -- confirm against the caller.

    Returns:
        A ``(tokens_a, tokens_b, is_next_random)`` tuple. ``tokens_a`` and
        ``tokens_b`` are flat lists of tokens; ``is_next_random`` is True
        when the two segments were swapped, i.e. ``tokens_a`` then holds
        the trailing sentences and ``tokens_b`` the leading ones.

    Raises:
        AssertionError: If ``sample`` has fewer than two sentences.
    """
    n_sentences = len(sample)
    # Both segments need at least one sentence.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # `a_end` is the number of leading sentences assigned to segment A.
    # With exactly two sentences the split point is forced, so the
    # generator is only consulted when there is an actual choice.
    a_end = 1
    if n_sentences >= 3:
        # With an exclusive upper bound this yields 1..n_sentences-1, which
        # always leaves at least one sentence for segment B.
        a_end = np_rng.randint(1, n_sentences)

    # Flatten the sentences on each side of the split point into token lists.
    tokens_a = [token for j in range(a_end) for token in sample[j]]
    tokens_b = [token for j in range(a_end, n_sentences) for token in sample[j]]

    # Swap the two segments half of the time; the flag records the swap.
    is_next_random = bool(np_rng.random() < 0.5)
    if is_next_random:
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random
|
|