# NOTE: exported from a Hugging Face Space (original page header read "Spaces: Running").
import os | |
import json | |
import pandas as pd | |
import ast | |
import matplotlib.pyplot as plt | |
from matplotlib import rcParams | |
import argparse | |
import seaborn as sns | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
import numpy as np | |
if __name__ == "__main__":
    # Tag near-duplicate conversations: rows whose concatenated user prompts
    # occur more often than the --percentile count cutoff are marked
    # high_freq, and only a capped random sample of each such group keeps
    # sampled=True.  Writes the tagged records to <output_dir>/dedup.json.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()
    output_dir = args.output_dir

    with open(args.input_file) as f:
        data = json.load(f)
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: the dedup key is the concatenation of all user turns of
    # conversation_a, truncated to 10k chars so pathological conversations
    # don't dominate.  str.join avoids the quadratic += concatenation.
    for row in data:
        user_turns = (
            f"{turn['content']}\n"
            for turn in row["conversation_a"]
            if turn["role"] == "user"
        )
        row["post_process_conv"] = "".join(user_turns)[:10000]
    df = pd.DataFrame(data)
    print("Number of conversations: ", len(df))

    prompt_counts = df["post_process_conv"].value_counts()
    # Show the top 20 most frequent prompts (diagnostic output only).
    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    # Determine the percentile count
    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile*100} percentile count: {percentile_cutoff}")

    # Prompts that are more common than the percentile cutoff.
    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # One FRESH tag dict per row.  (The previous version assigned a single
    # shared dict instance into many cells of the object array via fancy
    # indexing, so mutating one tag would have silently mutated all of them.)
    dedup_tags = np.array(
        [{"high_freq": False, "sampled": True} for _ in range(len(df))]
    )
    high_freq_groups = df.groupby("post_process_conv")
    for prompt in tqdm(high_frequency_prompts):
        df_high_freq = high_freq_groups.get_group(prompt)
        # Keep at most int(percentile_cutoff) rows of each over-represented
        # group; random_state is fixed for reproducibility.  Group sizes are
        # strictly greater than the cutoff, so the sample never over-draws.
        sampled_indices = df_high_freq.sample(
            n=int(percentile_cutoff), random_state=42
        ).index
        for idx in df_high_freq.index:
            dedup_tags[idx] = {"high_freq": True, "sampled": False}
        for idx in sampled_indices:
            dedup_tags[idx] = {"high_freq": True, "sampled": True}
    df["dedup_tag"] = dedup_tags

    # Drop the intermediate dedup key (post_process_conv) before writing.
    df = df.drop(columns=["post_process_conv"])
    df.to_json(
        os.path.join(output_dir, "dedup.json"),
        orient="records",
        indent=4,
        force_ascii=False,
    )