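# NOTE: the following lines are web-page residue from the file-hosting view,
# kept as comments so that the module parses as Python.
# MatchPrePrintArticles / create_negative_samples.py
# KNGCRIMSON's picture
# app
# b5cf002
# raw
# history blame
# 2.94 kB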
from src.dataset.GoodDataset import *
from src.dataset.NegativeSampler import *
import argparse
import os
def main(config):
    """
    Generate negative samples for a stored dataset and export both sample sets as CSV.

    The dataset is loaded from ``config.input``, the whole ``config`` namespace is
    handed to ``NegativeSampler.create_negative_samples``, and the resulting
    positive and negative samples are written to ``./data/pos.csv`` and
    ``./data/neg.csv`` (relative to the current working directory). A preview of
    the negative samples and the size of both sample sets are printed.

    Args:
        config: Namespace object containing the script arguments. Only
            ``config.input`` is read directly here; every other option is
            consumed (if at all) by the sampler.
    """
    dataset = AugmentedDataset()
    dataset.load(config.input)

    sampler = NegativeSampler(dataset)
    sampler.create_negative_samples(config)

    # Convert the negative samples once and reuse the result for both the
    # preview and the export instead of rebuilding it.
    negative_df = custom_struct_to_df(dataset.negative_samples)
    print(negative_df.head())

    # Make sure the export directory exists before writing; presumably the
    # converted objects are pandas DataFrames, whose to_csv does not create
    # missing parent directories -- TODO confirm.
    os.makedirs('./data', exist_ok=True)

    # NOTE(review): config.output is not read in this function; the CSV paths
    # below are hard-coded -- confirm whether the sampler uses config.output.
    custom_struct_to_df(dataset.positive_samples).to_csv('./data/pos.csv', index=False)
    negative_df.to_csv('./data/neg.csv', index=False)

    print(len(dataset.positive_samples))
    print(len(dataset.negative_samples))
if __name__ == "__main__":
    # Script entry point: declare the CLI, validate the option combination,
    # then run main() with the parsed namespace.
    from src.utils.io_utils import PROJECT_ROOT

    parser = argparse.ArgumentParser(
        description="Generate and save a dataset based on the given configuration."
    )

    # Input/output locations default to files under the project root.
    default_input = os.path.join(PROJECT_ROOT, "data/positive_samples.pkl")
    default_output = os.path.join(PROJECT_ROOT, "data/negative_samples.pkl")
    parser.add_argument(
        "-i", "--input", type=str, default=default_input,
        help="Input file path to load the positive samples.",
    )
    parser.add_argument(
        "-o", "--output", type=str, default=default_output,
        help="Output file path to save the negative samples.",
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=42,
        help="Random seed for reproducibility.",
    )

    # Boolean switches, registered in the order they appear in --help.
    switch_specs = (
        ("-r", "--random", "Utilization of `sample_random`"),
        ("-f", "--fuzz_title", "Utilization of `fuzz_title`"),
        ("-ra", "--replace_auth", "Utilization of `sample_authors_overlap_random`"),
        ("-oa", "--overlap_auth", "Utilization of `sample_authors_overlap`"),
        ("-ot", "--overlap_topic", "Utilization of `sample_similar_topic`"),
    )
    for short_flag, long_flag, description in switch_specs:
        parser.add_argument(short_flag, long_flag, action='store_true', help=description)

    # Integer tuning options (long form only), each with its default value.
    integer_specs = (
        ("--factor_max", 4, "Maximum number of negative samples to generate per positive sample."),
        ("--authors_to_consider", 1, "Number of authors to consider when overlapping authors."),
        ("--overlapping_authors", 1, "Minimum number of overlapping authors required."),
        ("--fuzz_count", -1, "Number of words to replace when fuzzing titles."),
    )
    for option_name, fallback, description in integer_specs:
        parser.add_argument(option_name, type=int, default=fallback, help=description)

    config = parser.parse_args()

    # parser.error() reports the message and exits, so at most one of the
    # checks below can fire.
    both_overlaps_requested = config.overlap_auth and config.overlap_topic
    if both_overlaps_requested:
        parser.error("Only one of --overlap_auth and --overlap_topic can be used.")

    any_required_switch = config.overlap_auth or config.overlap_topic or config.random
    if not any_required_switch:
        parser.error("At least one of --overlap_auth, --overlap_topic, or --random must be specified.")

    main(config)