Spaces:
Sleeping
Sleeping
from src.dataset.GoodDataset import * | |
from src.dataset.NegativeSampler import * | |
import argparse | |
import os | |
def main(config):
    """
    Generate negative samples and write positive/negative sample CSVs.

    Loads the positive samples from ``config.input`` into an
    ``AugmentedDataset``, runs ``NegativeSampler.create_negative_samples``
    with ``config``, prints a preview of the negative samples, and writes
    two CSV files, ``pos.csv`` and ``neg.csv``. Finally it prints the number
    of positive and negative samples.

    The CSVs are written to the directory containing ``config.output``
    (``<PROJECT_ROOT>/data`` with the default CLI arguments). If ``config``
    has no ``output`` attribute (or it is empty), ``./data`` is used.
    The directory is created if it does not exist.

    Args:
        config: Namespace object containing the script arguments.
    """
    dataset = AugmentedDataset()
    dataset.load(config.input)

    sampler = NegativeSampler(dataset)
    # NOTE(review): presumably this fills dataset.negative_samples (read below) — confirm.
    sampler.create_negative_samples(config)

    # Anchor the CSVs next to --output so their location follows the
    # PROJECT_ROOT-based CLI defaults instead of the current working directory.
    output_path = getattr(config, "output", None)
    if output_path:
        out_dir = os.path.dirname(os.path.abspath(output_path))
    else:
        out_dir = os.path.join(".", "data")
    # to_csv does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    # Convert each sample collection once and reuse the DataFrame.
    neg_df = custom_struct_to_df(dataset.negative_samples)
    print(neg_df.head())

    pos_df = custom_struct_to_df(dataset.positive_samples)
    pos_df.to_csv(os.path.join(out_dir, "pos.csv"), index=False)
    neg_df.to_csv(os.path.join(out_dir, "neg.csv"), index=False)

    print(len(dataset.positive_samples))
    print(len(dataset.negative_samples))
if __name__ == "__main__":
    # PROJECT_ROOT is only needed to build the CLI defaults, so it is imported
    # inside the script entry point rather than at module import time.
    from src.utils.io_utils import PROJECT_ROOT

    cli = argparse.ArgumentParser(description="Generate and save a dataset based on the given configuration.")

    # File locations and reproducibility.
    cli.add_argument(
        "-i", "--input", type=str,
        default=os.path.join(PROJECT_ROOT, "data/positive_samples.pkl"),
        help="Input file path to load the positive samples.",
    )
    cli.add_argument(
        "-o", "--output", type=str,
        default=os.path.join(PROJECT_ROOT, "data/negative_samples.pkl"),
        help="Output file path to save the negative samples.",
    )
    cli.add_argument("-s", "--seed", type=int, default=42, help="Random seed for reproducibility.")

    # Boolean switches; each help text names the function the switch refers to.
    # Registration order is kept so --help lists them in the same order.
    for short_flag, long_flag, func_name in (
        ("-r", "--random", "sample_random"),
        ("-f", "--fuzz_title", "fuzz_title"),
        ("-ra", "--replace_auth", "sample_authors_overlap_random"),
        ("-oa", "--overlap_auth", "sample_authors_overlap"),
        ("-ot", "--overlap_topic", "sample_similar_topic"),
    ):
        cli.add_argument(short_flag, long_flag, action='store_true', help=f"Utilization of `{func_name}`")

    # Integer tuning knobs forwarded to the sampler via the parsed namespace.
    for long_flag, fallback, text in (
        ("--factor_max", 4, "Maximum number of negative samples to generate per positive sample."),
        ("--authors_to_consider", 1, "Number of authors to consider when overlapping authors."),
        ("--overlapping_authors", 1, "Minimum number of overlapping authors required."),
        ("--fuzz_count", -1, "Number of words to replace when fuzzing titles."),
    ):
        cli.add_argument(long_flag, type=int, default=fallback, help=text)

    config = cli.parse_args()

    # --overlap_auth and --overlap_topic may not be combined.
    if config.overlap_auth and config.overlap_topic:
        cli.error("Only one of --overlap_auth and --overlap_topic can be used.")
    # NOTE(review): --fuzz_title and --replace_auth alone do not satisfy this
    # check; confirm that is intended.
    if not any((config.overlap_auth, config.overlap_topic, config.random)):
        cli.error("At least one of --overlap_auth, --overlap_topic, or --random must be specified.")

    main(config)