|
--- |
|
license: apache-2.0 |
|
--- |
|
|
|
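## Generate training data

The snippets below assume `train_df` and `val_df` are pandas DataFrames with `query`, `document`, and `relevance_score` columns, and that `InputExample` and `CrossEncoder` (from `sentence_transformers`), `DataLoader` (from `torch.utils.data`), and `pearsonr`/`spearmanr` (from `scipy.stats`) have already been imported.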
|
``` |
|
def df_to_input_examples(df):
    """Convert a DataFrame of labelled query/document pairs into ``InputExample`` objects.

    Args:
        df: DataFrame with ``query``, ``document`` and ``relevance_score`` columns.

    Returns:
        A list with one ``InputExample(texts=[query, document], label=score)``
        per row, in row order, with the score cast to ``float``.
    """
    # Zip the three columns directly rather than using DataFrame.iterrows(),
    # which constructs a pandas Series for every row and is much slower on
    # large frames; the values handed to InputExample are the same.
    return [
        InputExample(texts=[query, document], label=float(score))
        for query, document, score in zip(
            df['query'], df['document'], df['relevance_score']
        )
    ]
|
|
|
# Build the InputExample lists for the training and validation splits.
train_samples, val_samples = (
    df_to_input_examples(split_df) for split_df in (train_df, val_df)
)

# Feed the training examples to the trainer in shuffled batches of 16 pairs.
train_dataloader = DataLoader(train_samples, batch_size=16, shuffle=True)
|
``` |
|
|
|
## Create Evaluator class |
|
``` |
|
class CrossEncoderEvaluator:
    """Score a cross-encoder on held-out pairs by correlation with gold labels.

    Calling an instance returns the mean of the Pearson and Spearman
    correlations between ``model.predict`` outputs and the sample labels, so
    higher is better.
    """

    def __init__(self, eval_samples):
        # Each sample must expose ``texts`` (at least two strings: query,
        # document) and a numeric ``label``, as produced by df_to_input_examples.
        self.eval_samples = eval_samples

    def __call__(self, model, **kwargs):
        """Return the averaged Pearson/Spearman correlation for ``model``.

        Args:
            model: Object with a ``predict(pairs)`` method returning one score
                per ``[query, document]`` pair.
            **kwargs: Extra keywords supplied by the training loop (for example
                ``output_path``, ``epoch``, ``steps``); accepted and ignored.

        Returns:
            A float in [-1, 1]. Returns 0.0 when the correlation is undefined
            (e.g. all predictions or all labels are identical).
        """
        import math

        pairs = [[sample.texts[0], sample.texts[1]] for sample in self.eval_samples]
        labels = [sample.label for sample in self.eval_samples]
        predictions = model.predict(pairs)

        pearson_corr, _ = pearsonr(predictions, labels)
        spearman_corr, _ = spearmanr(predictions, labels)
        score = (pearson_corr + spearman_corr) / 2

        # Correlation is NaN for constant inputs. A NaN score never compares
        # greater than a previous best, so a loop that keeps the best-scoring
        # checkpoint would never save one; report "no correlation" instead.
        if math.isnan(score):
            return 0.0
        return float(score)
|
|
|
# Evaluate on the held-out validation examples; this object is handed to
# model.fit() in the training step.
evaluator = CrossEncoderEvaluator(eval_samples=val_samples)
|
``` |
|
|
|
## Train the model |
|
``` |
|
# Start from the MS MARCO MiniLM cross-encoder with a single output label,
# i.e. one relevance score per (query, document) pair.
base_checkpoint = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
model = CrossEncoder(base_checkpoint, num_labels=1)

# Fine-tune on the training pairs, running the evaluator every 500 steps.
# NOTE(review): 100 epochs is a lot for fine-tuning a pretrained
# cross-encoder; this presumably relies on the evaluator score to keep the
# best checkpoint in output_path — confirm against the installed
# sentence-transformers CrossEncoder.fit behaviour.
model.fit(
    train_dataloader=train_dataloader,
    epochs=100,
    warmup_steps=100,
    evaluator=evaluator,
    evaluation_steps=500,
    output_path='fine_tuned_reranker',
)
|
``` |
|
|
|
## Usage |
|
``` |
|
# Reload the fine-tuned reranker from the directory used as output_path
# during training.
checkpoint_dir = 'fine_tuned_reranker'
reranker_model = CrossEncoder(checkpoint_dir)
|
|
|
def search_and_rerank(query, documents, top_k=10, model=None):
    """Score each document against ``query`` with a cross-encoder and rank them.

    Args:
        query: The search query string.
        documents: Iterable of candidate document strings to rerank.
        top_k: Maximum number of results to return (default 10). ``None``
            returns every document; negative values return an empty list.
        model: Object with a ``predict(pairs)`` method returning one score per
            ``(query, document)`` pair. Defaults to the module-level
            ``reranker_model``.

    Returns:
        A list of ``(document, score)`` tuples sorted by descending score,
        truncated to ``top_k`` entries.
    """
    # Materialise once: ``documents`` is walked twice (to build the pairs and
    # to zip with the scores), which would silently yield nothing for a generator.
    documents = list(documents)
    if not documents:
        # Nothing to score; skip calling predict() with an empty batch.
        return []

    scorer = reranker_model if model is None else model

    # Prepare pairs for reranking
    pairs = [(query, doc) for doc in documents]

    # One score per pair; float() turns NumPy scalars into plain Python
    # floats, the same values ndarray.tolist() would give for a 1-D array.
    rerank_scores = [float(score) for score in scorer.predict(pairs)]

    # Sort results by reranker scores, best first
    reranked_results = sorted(
        zip(documents, rerank_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )

    if top_k is None:
        return reranked_results
    # Clamp negatives to 0 so they mean "nothing" rather than slicing from the end.
    return reranked_results[:max(top_k, 0)]
|
|
|
import time

query = "OPPO 8GB 128G"

documents = [
    "OPPO Reno11F 5G 8GB-256GB",
    "OPPO Reno11F 5G 8GB-32GB",
    "OPPO Reno11F 5G 16GB-128GB",
    "Samsung galaxy 128GB",
    "Samsung S24 128GB",
    # ...
]

# Time one rerank call. perf_counter() is a monotonic, high-resolution clock
# intended for measuring intervals, unlike time.time(), which follows the
# (adjustable) wall clock.
start_time = time.perf_counter()
# Ask for every document so the full ranking is printed.
results = search_and_rerank(query, documents, len(documents))
end_time = time.perf_counter()

execution_time = (end_time - start_time) * 1000  # seconds -> milliseconds
print(f"Execution time: {execution_time:.4f} milliseconds")

print(f"Query: \t\t\t\t{query}")
for document, score in results:
    print(f"Score: {score:.4f} | Document: {document}")
|
``` |
|
|
|
Credit goes to: [email protected] |