|
"""Sample a corpus with temperature, grouping lines by language code.

Assumes the input corpus file is sorted (i.e. grouped) by language code.
Sampled lines are written to stdout.
"""
|
|
|
import argparse |
|
import logging |
|
import random |
|
import sys |
|
|
|
def parse_args():
    """Read the two positional command-line arguments.

    Returns:
        argparse.Namespace with ``corpus_filepath`` (the corpus to sample) and
        ``linecounts_filepath`` (per-language counts, as produced by ``uniq -c``).
    """
    cli = argparse.ArgumentParser()
    positional_specs = (
        ("corpus_filepath", "path to input corpus to sample"),
        ("linecounts_filepath", "path to file containing line counts of input corpus (from 'uniq -c')"),
    )
    for dest, help_text in positional_specs:
        cli.add_argument(dest, type=str, help=help_text)
    return cli.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _read_line_counts(linecounts_filepath, logger):
    """Build the per-language lookup from a ``uniq -c`` style counts file.

    Each non-blank line must look like ``<count> <langcode>``; surrounding
    whitespace (``uniq -c`` right-aligns the count) is ignored.

    Args:
        linecounts_filepath: path to the counts file.
        logger: logger used for progress messages.

    Returns:
        ``(lc_lookup, total_raw_lines)``: ``lc_lookup`` maps each language code
        to ``{"raw_lines": count}``; ``total_raw_lines`` is the sum of counts.

    Raises:
        ValueError: on a malformed line, a non-integer or negative count, or a
            language code listed more than once.
    """
    logger.info(f"creating counts lookup dict from {linecounts_filepath}")
    lc_lookup = {}
    total_raw_lines = 0
    with open(linecounts_filepath) as f:
        for line in f:
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # tolerate blank lines instead of crashing on unpacking
            if len(fields) != 2:
                raise ValueError(f"malformed line in {linecounts_filepath}: {line!r}")
            count_str, lang = fields
            count = int(count_str)
            if count < 0:
                raise ValueError(f"negative count for {lang!r} in {linecounts_filepath}")
            if lang in lc_lookup:
                # uniq -c only merges adjacent lines, so a repeated code suggests
                # ungrouped input; overwriting would desync lookup and total.
                raise ValueError(f"language {lang!r} listed more than once in {linecounts_filepath}")
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count

    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")
    return lc_lookup, total_raw_lines


def _compute_lines_to_sample(lc_lookup, total_raw_lines, logger, exponent=0.3):
    """Add ``sampling_factor`` and ``lines_to_sample`` to every ``lc_lookup`` entry.

    Each language's share of the corpus is raised to ``exponent``; with an
    exponent below 1, small shares grow relative to large ones, flattening the
    language distribution. The factors are renormalised so the per-language
    targets sum (up to rounding) to ``total_raw_lines``.

    Args:
        lc_lookup: mapping language code -> ``{"raw_lines": int}``; mutated in place.
        total_raw_lines: sum of all ``raw_lines`` values.
        logger: logger used for progress messages.
        exponent: power applied to each language's share (default 0.3).

    Returns:
        The total number of lines that will be written across all languages.

    Raises:
        ValueError: if the corpus is empty, or rounding makes the sampled total
            differ from ``total_raw_lines`` by 1% or more.
    """
    if total_raw_lines <= 0:
        raise ValueError("line counts describe an empty corpus; nothing to sample")

    logger.info("calculating sampling factors")
    total_sampling_factors = 0
    for stats in lc_lookup.values():
        stats["sampling_factor"] = (stats["raw_lines"] / total_raw_lines) ** exponent
        total_sampling_factors += stats["sampling_factor"]
    logger.info(f"sampling factor total is {total_sampling_factors}")

    logger.info("calculating number of lines to sample")
    total_lines_to_sample = 0
    for stats in lc_lookup.values():
        stats["lines_to_sample"] = round(
            stats["sampling_factor"] / total_sampling_factors * total_raw_lines)
        total_lines_to_sample += stats["lines_to_sample"]

    if total_lines_to_sample == 0:
        raise ValueError("no lines would be sampled")
    prop_size_difference = abs((total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is {total_lines_to_sample} ({prop_size_difference:.3%} difference)")
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if prop_size_difference >= 0.01:
        raise ValueError(
            f"sampled total {total_lines_to_sample} differs from raw total "
            f"{total_raw_lines} by {prop_size_difference:.3%} (limit is 1%)")
    return total_lines_to_sample


def _write_language_sample(langcode, lang_lines, lc_lookup, logger):
    """Sample one language's lines to its target count and write them to stdout.

    Args:
        langcode: language code of the group.
        lang_lines: every (stripped) corpus line belonging to ``langcode``.
        lc_lookup: lookup annotated by ``_compute_lines_to_sample``.
        logger: logger used for progress messages.

    Raises:
        ValueError: if ``langcode`` is missing from the counts file, or the
            number of lines read differs from its count (e.g. because the
            corpus is not grouped by language).
    """
    stats = lc_lookup.get(langcode)
    if stats is None:
        raise ValueError(f"language {langcode!r} appears in the corpus but not in the line counts file")
    raw_lines_in_lang = len(lang_lines)
    if raw_lines_in_lang != stats["raw_lines"]:
        raise ValueError(
            f"read {raw_lines_in_lang} lines for {langcode!r} but the counts file says "
            f"{stats['raw_lines']}; is the corpus grouped by language?")

    num_lines_to_keep = stats["lines_to_sample"]
    logger.info(f"finished reading {langcode}: read in {raw_lines_in_lang}, writing {num_lines_to_keep}")
    if raw_lines_in_lang >= num_lines_to_keep:
        # Down-sample without replacement; when sizes are equal every line is kept.
        sampled_lines = random.sample(lang_lines, num_lines_to_keep)
    else:
        # Up-sample: draw with replacement to reach the target size.
        sampled_lines = random.choices(lang_lines, k=num_lines_to_keep)
    sys.stdout.writelines(f"{out}\n" for out in sampled_lines)


def _sample_corpus(corpus_filepath, lc_lookup, logger):
    """Stream the corpus, sampling and writing each contiguous language group.

    Every line must have exactly three tab-separated fields, the second being
    the language code; lines of one language must be contiguous.

    Raises:
        ValueError: on a line without exactly three fields, or any check
            performed by ``_write_language_sample``.
    """
    logger.info(f"sampling from {corpus_filepath}")
    langcode = None  # language of the group being collected; None before the first line
    single_lang_line_store = []
    written = set()
    with open(corpus_filepath, "r") as f:
        for line in f:
            line = line.strip()
            fields = line.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"expected 3 tab-separated fields in {corpus_filepath}, got {len(fields)}: {line!r}")
            nextlang = fields[1]
            if langcode is not None and nextlang != langcode:
                _write_language_sample(langcode, single_lang_line_store, lc_lookup, logger)
                written.add(langcode)
                logger.info(f"finished writing {langcode} to stdout, now collecting lines for {nextlang}")
                single_lang_line_store = []
            single_lang_line_store.append(line)
            langcode = nextlang

    # Groups are only flushed when the next language starts, so the final
    # group must be written explicitly once the file is exhausted.
    if langcode is not None:
        _write_language_sample(langcode, single_lang_line_store, lc_lookup, logger)
        written.add(langcode)
        logger.info(f"finished writing {langcode} to stdout")

    missing = [lang for lang in lc_lookup if lang not in written]
    if missing:
        logger.warning(f"languages in the counts file but not in the corpus (nothing written): {', '.join(missing)}")


def main():
    """Sample the corpus named on the command line and write the result to stdout.

    Per-language target sizes are derived from the line counts file. Progress
    is logged to ``sampling.log`` (relative to the current working directory,
    overwritten on each run).
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    logger = logging.getLogger(__name__)

    args = parse_args()

    lc_lookup, total_raw_lines = _read_line_counts(args.linecounts_filepath, logger)
    _compute_lines_to_sample(lc_lookup, total_raw_lines, logger)
    _sample_corpus(args.corpus_filepath, lc_lookup, logger)

    logger.info("sampling complete!")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())