"""Gradio demo: semantic text deduplication of Hugging Face datasets with SemHash."""

from collections.abc import Iterator
from dataclasses import dataclass
from difflib import ndiff
from operator import itemgetter

import gradio as gr
from datasets import load_dataset
from model2vec import StaticModel
from semhash import SemHash
from semhash.datamodels import DeduplicationResult

# Default parameters
default_dataset_name = "ag_news"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9

# Load the model to use (done once at import time and shared by every request)
model = StaticModel.from_pretrained("minishlab/potion-base-8M")

# Horizontal rule placed between sections of the Markdown report.
_SEPARATOR = "-" * 50 + "\n\n"
# Maximum number of duplicate records shown as examples in the report.
_MAX_EXAMPLES = 5


@dataclass(frozen=True)
class _ReportLabels:
    """Mode-specific Markdown labels used when rendering example duplicates."""

    examples_heading: str  # heading printed above the list of examples
    original: str  # label for the matched (kept) text
    duplicate: str  # label for the text flagged as a duplicate
    unmatched: str  # label used when a flagged record carries no match details


_SINGLE_LABELS = _ReportLabels(
    examples_heading="**Example duplicates:**",
    original="**Original:**",
    duplicate="**Duplicate:**",
    unmatched="**Duplicate:**",
)

_CROSS_LABELS = _ReportLabels(
    examples_heading="**Example duplicates from Dataset 2:**",
    original="**Original (Dataset 1):**",
    duplicate="**Duplicate (Dataset 2):**",
    unmatched="**Potential Duplicate (Dataset 2):**",
)


def display_word_differences(x: str, y: str) -> str:
    """
    Display the word-level differences between two texts, formatted to avoid
    misinterpretation of Markdown syntax.

    Both texts are split on whitespace and compared with ``difflib.ndiff``.
    Only added (``+``) and removed (``-``) words are kept; the result is
    wrapped in a fenced code block. When the two texts contain the same words
    a short placeholder is shown instead of an empty code block.
    """
    changed = [
        entry
        for entry in ndiff(x.split(), y.split())
        if entry.startswith(("+", "-"))
    ]
    if not changed:
        # Exact duplicates (or whitespace-only differences) produce no +/- lines.
        changed = ["(no word-level differences)"]
    formatted_diff = "\n".join(changed)
    return f"```\n{formatted_diff}\n```"


def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
    """
    Load texts from a specified dataset split.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        dataset_split: Split to load (e.g. ``"train"``).
        text_column: Column whose values are returned.

    Returns:
        The values of ``text_column``, in dataset order.

    Raises:
        KeyError: If ``text_column`` is not a column of the loaded split.
    """
    ds = load_dataset(dataset_name, split=dataset_split)
    if text_column not in ds.column_names:
        raise KeyError(
            f"Column {text_column!r} not found in {dataset_name}/{dataset_split}; "
            f"available columns: {ds.column_names}"
        )
    # Read the single column directly instead of materialising every full row.
    return list(ds[text_column])


def deduplicate_single_dataset(texts: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate within a single dataset using SemHash, treating each text as a raw string record."""
    # Build a SemHash index from the raw texts
    semhash = SemHash.from_records(records=texts, model=model)
    # Deduplicate the entire dataset
    return semhash.self_deduplicate(threshold=threshold)


def deduplicate_two_datasets(texts1: list[str], texts2: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    # Build SemHash index on dataset1
    semhash = SemHash.from_records(records=texts1, model=model)
    # Deduplicate texts2 against dataset1
    return semhash.deduplicate(records=texts2, threshold=threshold)


def _require_non_empty(fields: dict) -> None:
    """Raise ``ValueError`` naming every entry of ``fields`` whose value is blank."""
    missing = [label for label, value in fields.items() if not (value or "").strip()]
    if missing:
        raise ValueError(f"Missing required input(s): {', '.join(missing)}")


def _format_report(result: DeduplicationResult, summary_lines: list[str], labels: _ReportLabels) -> str:
    """
    Render the Markdown report for a deduplication result.

    Args:
        result: Result whose ``duplicates`` entries expose ``record`` (the
            flagged text) and ``duplicates`` (a list of ``(text, score)`` pairs).
        summary_lines: Summary statistics, one Markdown line each.
        labels: Mode-specific labels for the example section.

    Returns:
        The summary followed by up to ``_MAX_EXAMPLES`` example duplicates, or
        by "No duplicates found." when nothing was flagged.
    """
    parts = [line + "\n\n" for line in summary_lines]
    parts.append(_SEPARATOR)

    if not result.duplicates:
        parts.append("No duplicates found.")
        return "".join(parts)

    parts.append(labels.examples_heading + "\n\n")
    for duprec in result.duplicates[:_MAX_EXAMPLES]:
        dup_text = duprec.record
        if duprec.duplicates:
            # Only the best match is displayed, so pick it directly rather than
            # sorting every match list. max() returns the first maximal
            # element, matching a stable descending sort.
            orig_text, score = max(duprec.duplicates, key=itemgetter(1))
            differences = display_word_differences(orig_text, dup_text)
            parts.append(
                f"{labels.original}\n{orig_text}\n\n"
                f"{labels.duplicate}\n{dup_text}\n\n"
                f"**Similarity Score:** {score:.4f}\n"
                f"**Differences:**\n{differences}\n"
            )
        else:
            # Flagged record without match details (possibly an exact duplicate).
            parts.append(
                f"{labels.unmatched}\n{dup_text}\n\n"
                "No near-duplicate details available.\n"
            )
        parts.append(_SEPARATOR)
    return "".join(parts)


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Iterator[tuple[str, str]]:
    """
    Perform deduplication on one or two datasets using SemHash.

    This function streams status updates to Gradio for user feedback.

    Args:
        deduplication_type: ``"Single dataset"`` deduplicates Dataset 1 against
            itself; any other value deduplicates Dataset 2 against Dataset 1.
        dataset1_name, dataset1_split, dataset1_text_column: Location of the
            Dataset 1 texts.
        dataset2_name, dataset2_split, dataset2_text_column: Location of the
            Dataset 2 texts (only used for cross-dataset deduplication).
        threshold: Similarity threshold passed to SemHash.
        progress: Gradio progress tracker; not used directly in the body, it is
            declared so Gradio can mirror tqdm progress bars in the UI.

    Yields:
        ``(status, result_markdown)`` pairs. Intermediate updates carry an
        empty result; the final one carries the full report.

    Raises:
        Exception: Any error is first reported as a status update and then
            re-raised unchanged.
    """
    try:
        threshold = float(threshold)
        single = deduplication_type == "Single dataset"

        # Fail fast on blank inputs instead of after a potentially long load.
        required = {
            "Dataset 1 name": dataset1_name,
            "Dataset 1 split": dataset1_split,
            "Dataset 1 text column": dataset1_text_column,
        }
        if not single:
            required.update(
                {
                    "Dataset 2 name": dataset2_name,
                    "Dataset 2 split": dataset2_split,
                    "Dataset 2 text column": dataset2_text_column,
                }
            )
        _require_non_empty(required)

        # Load Dataset 1
        yield "Loading Dataset 1...", ""
        texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)

        if single:
            # Single-dataset deduplication
            yield "Deduplicating within Dataset 1 (SemHash)...", ""
            result = deduplicate_single_dataset(texts1, threshold=threshold)
            summary_lines = [
                f"**Total documents (Dataset 1):** {len(texts1)}",
                f"**Duplicates found:** {len(result.duplicates)}",
                f"**Unique documents after deduplication:** {len(result.deduplicated)}",
            ]
            labels = _SINGLE_LABELS
        else:
            # Cross-dataset deduplication
            yield "Loading Dataset 2...", ""
            texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)

            yield "Deduplicating Dataset 2 against Dataset 1 (SemHash)...", ""
            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)
            summary_lines = [
                f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}",
                f"**Duplicates found in Dataset 2:** {len(result.duplicates)}",
                f"**Unique documents after deduplication:** {len(result.deduplicated)}",
            ]
            labels = _CROSS_LABELS

        yield "Deduplication completed.", _format_report(result, summary_lines, labels)

    except Exception as e:
        # Surface the problem in the status panel, then re-raise.
        yield f"An error occurred: {e}", ""
        raise


# --- Gradio App ---
with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }") as demo:
    gr.Markdown("# Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    **NOTE**: This demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    # Dataset 2 inputs are only relevant (and only shown) in cross-dataset mode.
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

    with gr.Row():
        compute_button = gr.Button("Deduplicate")

    status_output = gr.Markdown(elem_id="status_output")
    result_output = gr.Markdown()

    def update_visibility(choice: str):
        """Show the Dataset 2 inputs only when cross-dataset mode is selected."""
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[status_output, result_output],
    )

demo.launch()