Pringled committed
Commit b54da62 · unverified · 1 Parent(s): e9a1430

Updated code with SemHash

Files changed (2)
  1. app.py +114 -130
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,88 +1,55 @@
  import gradio as gr
  from datasets import load_dataset
- import numpy as np
- from model2vec import StaticModel
- from reach import Reach
  from difflib import ndiff

- # Load the model
- model = StaticModel.from_pretrained("minishlab/potion-base-8M")
+ from semhash import SemHash
+ from semhash.datamodels import DeduplicationResult
+
+ from model2vec import StaticModel

  # Default parameters
  default_dataset_name = "ag_news"
- default_dataset1_split = "train"  # Default for the first dataset is "train"
- default_dataset2_split = "test"  # Default for the second dataset is "test"
+ default_dataset1_split = "train"
+ default_dataset2_split = "test"
  default_text_column = "text"
  default_threshold = 0.9

- def deduplicate_embeddings(
-     embeddings_a: np.ndarray,
-     embeddings_b: np.ndarray = None,
-     threshold: float = 0.9,
-     batch_size: int = 1024,
-     progress=None
- ) -> tuple[np.ndarray, dict[int, int]]:
-     """
-     Deduplicate embeddings within one dataset or across two datasets.
-
-     :param embeddings_a: Embeddings of Dataset 1.
-     :param embeddings_b: Optional, embeddings of Dataset 2.
-     :param threshold: Similarity threshold for deduplication.
-     :param batch_size: Batch size for similarity computation.
-     :param progress: Gradio progress tracker for feedback.
-     :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
-     """
-     if embeddings_b is None:
-         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
-         duplicate_to_original = {}
-         results = reach.nearest_neighbor_threshold(
-             embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
-         )
-         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
-             for sim_idx, _ in similar_items:
-                 sim_idx = int(sim_idx)
-                 if sim_idx != i and sim_idx not in duplicate_to_original:
-                     duplicate_to_original[sim_idx] = i
-         deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
-         return deduplicated_indices, duplicate_to_original
-     else:
-         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
-         duplicate_indices_in_b = []
-         duplicate_to_original = {}
-         results = reach.nearest_neighbor_threshold(
-             embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
-         )
-         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
-             if similar_items:
-                 duplicate_indices_in_b.append(i)
-                 duplicate_to_original[i] = int(similar_items[0][0])
-         return duplicate_indices_in_b, duplicate_to_original
+ # Load the model to use
+ model = StaticModel.from_pretrained("minishlab/potion-base-8M")
+

  def display_word_differences(x: str, y: str) -> str:
      """
      Display the word-level differences between two texts, formatted to avoid
      misinterpretation of Markdown syntax.
-
-     :param x: First text.
-     :param y: Second text.
-     :return: A string showing word-level differences, wrapped in a code block.
      """
      diff = ndiff(x.split(), y.split())
      formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
      return f"```\n{formatted_diff}\n```"

- def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
-     """
-     Load texts from a specified dataset and split.
-
-     :param dataset_name: Name of the dataset.
-     :param dataset_split: Split of the dataset (e.g., 'train', 'validation', 'test').
-     :param text_column: Name of the text column.
-     :return: A list of texts from the dataset.
-     """
+ def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
+     """Load texts from a specified dataset split."""
      ds = load_dataset(dataset_name, split=dataset_split)
      return [example[text_column] for example in ds]

+
+ def deduplicate_single_dataset(texts: list[str], threshold: float) -> DeduplicationResult:
+     """Deduplicate within a single dataset using SemHash, treating each text as a raw string record."""
+     # Build a SemHash index from the raw texts
+     semhash = SemHash.from_records(records=texts, model=model)
+     # Deduplicate the entire dataset
+     return semhash.self_deduplicate(threshold=threshold)
+
+
+ def deduplicate_two_datasets(texts1: list[str], texts2: list[str], threshold: float) -> DeduplicationResult:
+     """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
+     # Build SemHash index on dataset1
+     semhash = SemHash.from_records(records=texts1, model=model)
+     # Deduplicate texts2 against dataset1
+     return semhash.deduplicate(records=texts2, threshold=threshold)
+
+
  def perform_deduplication(
      deduplication_type: str,
      dataset1_name: str,
@@ -95,93 +62,107 @@ def perform_deduplication(
      progress: gr.Progress = gr.Progress(track_tqdm=True)
  ):
      """
-     Perform deduplication on one or two datasets based on the deduplication type.
-
-     :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
-     :param dataset1_name: Name of the first dataset.
-     :param dataset1_split: Split of the first dataset.
-     :param dataset1_text_column: Text column of the first dataset.
-     :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
-     :param dataset2_split: Optional, split of the second dataset.
-     :param dataset2_text_column: Optional, text column of the second dataset.
-     :param threshold: Similarity threshold for deduplication.
-     :param progress: Gradio progress tracker.
-     :return: Status updates and result text for the Gradio interface.
+     Perform deduplication on one or two datasets using SemHash. This function
+     streams status updates to Gradio for user feedback.
      """
      try:
          threshold = float(threshold)

-         # Load and process Dataset 1
+         # Load Dataset 1
          yield "Loading Dataset 1...", ""
          texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
-         yield "Computing embeddings for Dataset 1...", ""
-         embeddings1 = model.encode(texts1, show_progressbar=True)

          if deduplication_type == "Single dataset":
-             # Deduplicate within Dataset 1
-             yield "Deduplicating within Dataset 1...", ""
-             deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
-                 embeddings1, threshold=threshold, progress=progress
-             )
-
-             num_duplicates = len(duplicate_mapping)
+             # Single-dataset deduplication
+             yield "Deduplicating within Dataset 1 (SemHash)...", ""
+             result = deduplicate_single_dataset(texts1, threshold=threshold)
+
+             # Sort all duplicates in descending order of their highest score
+             for duprec in result.duplicates:
+                 duprec.duplicates.sort(key=lambda x: x[1], reverse=True)
+
+             # Summarize results
+             num_duplicates = len(result.duplicates)
+             deduplicated_count = len(result.deduplicated)
+             total_docs = len(texts1)
+
              result_text = (
-                 f"**Total documents:** {len(texts1)}\n\n"
+                 f"**Total documents (Dataset 1):** {total_docs}\n\n"
                  f"**Duplicates found:** {num_duplicates}\n\n"
-                 f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
+                 f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                  + "-" * 50 + "\n\n"
              )

+             # Show example duplicates
              if num_duplicates > 0:
                  result_text += "**Example duplicates:**\n\n"
-                 for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
-                     orig_text = texts1[orig_idx]
-                     dup_text = texts1[dup_idx]
-                     differences = display_word_differences(orig_text, dup_text)
-                     result_text += (
-                         f"**Original:**\n{orig_text}\n\n"
-                         f"**Duplicate:**\n{dup_text}\n\n"
-                         f"**Differences:**\n{differences}\n"
-                         + "-" * 50 + "\n\n"
-                     )
+                 for duprec in result.duplicates[:5]:
+                     dup_text = duprec.record
+                     if duprec.duplicates:
+                         orig_text, score = duprec.duplicates[0]
+                         differences = display_word_differences(orig_text, dup_text)
+                         result_text += (
+                             f"**Original:**\n{orig_text}\n\n"
+                             f"**Duplicate:**\n{dup_text}\n\n"
+                             f"**Similarity Score:** {score:.4f}\n"
+                             f"**Differences:**\n{differences}\n"
+                             + "-" * 50 + "\n\n"
+                         )
+                     else:
+                         # Possibly an exact duplicate cluster
+                         result_text += (
+                             f"**Duplicate:**\n{dup_text}\n\n"
+                             "No near-duplicate details available.\n"
+                             + "-" * 50 + "\n\n"
+                         )
              else:
                  result_text += "No duplicates found."

              yield "Deduplication completed.", result_text

          else:
-             # Load and process Dataset 2
+             # Cross-dataset deduplication
              yield "Loading Dataset 2...", ""
              texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
-             yield "Computing embeddings for Dataset 2...", ""
-             embeddings2 = model.encode(texts2, show_progressbar=True)

-             # Deduplicate Dataset 2 against Dataset 1
-             yield "Deduplicating Dataset 2 against Dataset 1...", ""
-             duplicate_indices, duplicate_mapping = deduplicate_embeddings(
-                 embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
-             )
+             yield "Deduplicating Dataset 2 against Dataset 1 (SemHash)...", ""
+             result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)
+
+             # Sort duplicates in descending order of their highest score
+             for duprec in result.duplicates:
+                 duprec.duplicates.sort(key=lambda x: x[1], reverse=True)
+
+             num_duplicates = len(result.duplicates)
+             total_docs2 = len(texts2)
+             deduplicated_count = len(result.deduplicated)

-             num_duplicates = len(duplicate_indices)
              result_text = (
-                 f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
+                 f"**Total documents in {dataset2_name}/{dataset2_split}:** {total_docs2}\n\n"
                  f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
-                 f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
+                 f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                  + "-" * 50 + "\n\n"
              )

              if num_duplicates > 0:
                  result_text += "**Example duplicates from Dataset 2:**\n\n"
-                 for idx in duplicate_indices[:5]:
-                     orig_text = texts1[duplicate_mapping[idx]]
-                     dup_text = texts2[idx]
-                     differences = display_word_differences(orig_text, dup_text)
-                     result_text += (
-                         f"**Original (Dataset 1):**\n{orig_text}\n\n"
-                         f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
-                         f"**Differences:**\n{differences}\n"
-                         + "-" * 50 + "\n\n"
-                     )
+                 for duprec in result.duplicates[:5]:
+                     dup_text = duprec.record  # The "duplicate" text from dataset2
+                     if duprec.duplicates:
+                         orig_text, score = duprec.duplicates[0]
+                         differences = display_word_differences(orig_text, dup_text)
+                         result_text += (
+                             f"**Original (Dataset 1):**\n{orig_text}\n\n"
+                             f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
+                             f"**Similarity Score:** {score:.4f}\n"
+                             f"**Differences:**\n{differences}\n"
+                             + "-" * 50 + "\n\n"
+                         )
+                     else:
+                         result_text += (
+                             f"**Potential Duplicate (Dataset 2):**\n{dup_text}\n\n"
+                             "No near-duplicate details available.\n"
+                             + "-" * 50 + "\n\n"
+                         )
              else:
                  result_text += "No duplicates found."
@@ -191,44 +172,47 @@ def perform_deduplication(
          yield f"An error occurred: {e}", ""
          raise e

- # Gradio app with stop button support
+
+ # --- Gradio App ---
  with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }") as demo:
-     gr.Markdown("# Semantic Deduplication")
+     gr.Markdown("# Semantic Text Deduplication Using SemHash")
      gr.Markdown("""
-     This demo showcases semantic deduplication using Model2Vec for HuggingFace datasets.
-     It can be used to identify duplicate texts within a single dataset or across two datasets.
-     You can adjust the similarity threshold to control the strictness of the deduplication.\n
-     NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
+     This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
+     It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
+     You can adjust the similarity threshold to control the strictness of the deduplication.
+
+     **NOTE**: This demo runs on a free CPU backend, so it may be slow for large datasets.
+     For faster results, please run the code locally.
      """)

      deduplication_type = gr.Radio(
-         choices=["Cross-dataset", "Single dataset"],  # Swapped "Cross-dataset" to the left
+         choices=["Cross-dataset", "Single dataset"],
          label="Deduplication Type",
-         value="Cross-dataset",  # Set "Cross-dataset" as the default value
+         value="Cross-dataset",  # default
      )

      with gr.Row():
          dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
-         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")  # Default split is "train"
+         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
          dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

-     dataset2_inputs = gr.Column(visible=True)  # Make dataset2_inputs visible by default
+     dataset2_inputs = gr.Column(visible=True)
      with dataset2_inputs:
          with gr.Row():
              dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
-             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")  # Default split is "test"
+             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
              dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

      threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

-     with gr.Row():  # Placing the button in the same row for better alignment
+     with gr.Row():
          compute_button = gr.Button("Deduplicate")

      status_output = gr.Markdown(elem_id="status_output")
      result_output = gr.Markdown()

      def update_visibility(choice: str):
-         return gr.update(visible=choice == "Cross-dataset")
+         return gr.update(visible=(choice == "Cross-dataset"))

      deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

@@ -247,5 +231,5 @@ with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; over
          outputs=[status_output, result_output],
      )

-
  demo.launch()
+
requirements.txt CHANGED
@@ -1,5 +1,4 @@
- reach < 5
- model2vec
+ semhash>=0.2.0
  numpy
  datasets
  tqdm
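
For reference, a minimal sketch of the SemHash flow this commit introduces, runnable locally (the in-app note recommends running locally for speed). It uses only calls that appear in the diff above; the `ag_news` splits and the 0.9 threshold mirror the demo's defaults:

```python
# Minimal local sketch of the SemHash flow introduced in this commit.
# Result fields (.duplicates, .deduplicated, .record) follow their usage in app.py above.
from datasets import load_dataset
from model2vec import StaticModel
from semhash import SemHash

model = StaticModel.from_pretrained("minishlab/potion-base-8M")
train_texts = [row["text"] for row in load_dataset("ag_news", split="train")]

# Single-dataset deduplication, as in deduplicate_single_dataset()
result = SemHash.from_records(records=train_texts, model=model).self_deduplicate(threshold=0.9)
print(f"{len(result.duplicates)} duplicates, {len(result.deduplicated)} unique documents")

# Cross-dataset: flag test records that near-duplicate a training record,
# as in deduplicate_two_datasets()
test_texts = [row["text"] for row in load_dataset("ag_news", split="test")]
result = SemHash.from_records(records=train_texts, model=model).deduplicate(records=test_texts, threshold=0.9)
for duprec in result.duplicates[:3]:
    if duprec.duplicates:  # list of (original_text, similarity_score) pairs
        orig_text, score = duprec.duplicates[0]
        print(f"score={score:.4f}: {duprec.record[:60]!r} ~ {orig_text[:60]!r}")
```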