# OpenLID-v2 / scripts/sample_with_temperature.py
# Uploaded by laurievb via huggingface_hub (commit 895b334, verified)
"""samples with temperature, grouping by language code. assumes input file is sorted by language group"""
import argparse
import logging
import random
import sys
def parse_args():
    """Define and parse the script's command-line arguments."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "corpus_filepath",
        type=str,
        help="path to input corpus to sample",
    )
    arg_parser.add_argument(
        "linecounts_filepath",
        type=str,
        help="path to file containing line counts of input corpus (from 'uniq -c')",
    )
    return arg_parser.parse_args()
# def count_lines(file):
# def blocks(files, size=65536):
# while True:
# b = files.read(size)
# if not b: break
# yield b
# with open(file, "r",encoding="utf-8",errors='ignore') as f:
# return (sum(bl.count("\n") for bl in blocks(f)))
def main():
    """Temperature-sample a language-ID corpus, writing sampled lines to stdout.

    Reads a tab-separated corpus (three fields, language code in the middle)
    that is assumed to be sorted by language code, plus a 'uniq -c'-style
    line-count file.  Each language is resampled with temperature exponent
    0.3 so small languages are upsampled (with replacement) and large
    languages are downsampled (without replacement).  Progress is logged to
    'sampling.log' in the current directory.
    """
    logging.basicConfig(
        level=logging.INFO,
        filename='sampling.log',
        filemode='w',
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    logger = logging.getLogger(__name__)
    args = parse_args()

    logger.info(f"creating counts lookup dict from {args.linecounts_filepath}")
    with open(args.linecounts_filepath) as f:
        total_raw_lines = 0
        lc_lookup = dict()
        for line in f:
            # 'uniq -c' pads the count with leading spaces; a bare split()
            # tolerates any amount of surrounding whitespace (the original
            # split(' ') broke on multiple internal spaces)
            count, lang = line.split()
            count = int(count)
            lc_lookup[lang] = {"raw_lines": count}
            total_raw_lines += count
    logger.info(f"lookup dict finished ({len(lc_lookup)} entries)")
    logger.info(f"dataset contains {total_raw_lines} lines")

    # guard: an empty counts file would otherwise cause ZeroDivisionError below
    if total_raw_lines == 0:
        logger.warning("line count file is empty; nothing to sample")
        return

    # calculate lines to keep with
    # (((raw_lines_in_lang / total_line_count) ** 0.3) / total_proportions) * total lines
    logger.info("calculating sampling factors")
    total_sampling_factors = 0
    for lang in lc_lookup:
        # p ** 0.3 flattens the distribution: we sample lines proportional to
        # this so smaller langs are upsampled and larger langs are downsampled
        sampling_factor = (lc_lookup[lang]['raw_lines'] / total_raw_lines) ** 0.3
        lc_lookup[lang]["sampling_factor"] = sampling_factor
        total_sampling_factors += sampling_factor
    logger.info(f"sampling factor total is {total_sampling_factors}")

    logger.info(f"calculating number of lines to sample")
    total_lines_to_sample = 0
    for lang in lc_lookup:
        lines_to_sample = round(lc_lookup[lang]["sampling_factor"] / total_sampling_factors * total_raw_lines)
        lc_lookup[lang]['lines_to_sample'] = lines_to_sample
        total_lines_to_sample += lines_to_sample
    prop_size_difference = abs((total_raw_lines - total_lines_to_sample) / total_lines_to_sample)
    assert prop_size_difference < 0.01  # sense check that sampled corpus is right size
    logger.info(
        f"total raw lines is {total_raw_lines}, total sampled lines is {total_lines_to_sample} ({prop_size_difference:.3%} difference)")

    def write_sampled_group(group_lang, store):
        # sample one completed language group and stream it to stdout
        raw_lines_in_lang = len(store)
        assert raw_lines_in_lang == lc_lookup[group_lang]["raw_lines"]  # sanity check it's same data
        num_lines_to_keep = lc_lookup[group_lang]["lines_to_sample"]
        logger.info(f"finished reading {group_lang}: read in {raw_lines_in_lang}, writing {num_lines_to_keep}")
        if raw_lines_in_lang > num_lines_to_keep:
            # downsampling: sample without replacement
            sampled = random.sample(store, num_lines_to_keep)
        else:
            # need to oversample, so use sampling with replacement
            sampled = random.choices(store, k=num_lines_to_keep)
        sys.stdout.writelines(f"{out}\n" for out in sampled)

    # assume input file is sorted by language group
    logger.info(f"sampling from {args.corpus_filepath}")
    with open(args.corpus_filepath, "r") as f:
        single_lang_line_store = []
        langcode = ""
        for line in f:
            line = line.strip()
            _, nextlang, _ = line.split('\t')
            if langcode == nextlang or langcode == "":  # same language group
                single_lang_line_store.append(line)
            else:  # language change, time to sample and write out
                write_sampled_group(langcode, single_lang_line_store)
                logger.info(f"finished writing {langcode} to stdout, now collecting lines for {nextlang}")
                single_lang_line_store = [line]
            langcode = nextlang
        # BUG FIX: the original only wrote a group out when the language
        # *changed*, so the final language group in the corpus was silently
        # dropped at EOF. Flush it here.
        if single_lang_line_store:
            write_sampled_group(langcode, single_lang_line_store)
    logger.info("sampling complete!")
# Run only when executed as a script (not on import).
if __name__ == "__main__":
    main()