---
license: mit
datasets:
  - AdamLucek/apple-environmental-report-QA-retrieval
base_model: sentence-transformers/all-MiniLM-L6-v2
pipeline_tag: feature-extraction
library_name: peft
---

# all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA

A query-only linear adapter for [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), trained on the [AdamLucek/apple-environmental-report-QA-retrieval](https://huggingface.co/datasets/AdamLucek/apple-environmental-report-QA-retrieval) dataset.

Adapters were trained for 10, 20, 30, and 40 epochs with the following setup (a minimal sketch of the training loop follows this list):

- Triplet margin loss (margin = 1.0, p = 2, i.e. Euclidean distance)
- AdamW optimizer
- Random negative sampling from irrelevant documents
- Learning rate: 0.003
- Batch size: 32
- Gradient norm clipping: 1.0
- Warmup steps: 100
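For reference, a minimal sketch of a training loop matching this setup, assuming the triplets are precomputed embedding tensors; the `train_triplets` input and the linear warmup scheduler are assumptions, not the actual training script (see the GitHub repo for that):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

def train_adapter(adapter, train_triplets, epochs=30):
    # train_triplets: list of (query_emb, positive_emb, negative_emb) tensors
    # precomputed with the frozen base model (assumed preprocessing step)
    loader = DataLoader(train_triplets, batch_size=32, shuffle=True)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=0.003)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100,
        num_training_steps=epochs * len(loader))
    # Margin = 1.0, p = 2 (Euclidean distance), as listed above
    loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    for _ in range(epochs):
        for query_emb, pos_emb, neg_emb in loader:
            optimizer.zero_grad()
            # Only the query embedding passes through the adapter;
            # document embeddings are left untouched (query-only adapter)
            loss = loss_fn(adapter(query_emb), pos_emb, neg_emb)
            loss.backward()
            # Gradient norm clipping at 1.0
            nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
```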

The training script and model-creation code are available in the GitHub repo.

## Assessment

The best-performing checkpoint was at 30 epochs:

| Metric | Baseline | 30-epoch adapter |
| --- | --- | --- |
| Hit Rate @10 | 61.860% | 66.628% |
| Mean Reciprocal Rank @10 | 0.31108 (avg. rank 3.2) | 0.33119 (avg. rank 3.0) |

That is a 7.7% relative improvement in hit rate and a 6.5% relative improvement in mean reciprocal rank over the base embedding model.
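For clarity, a minimal sketch of how these two metrics are conventionally computed; the `ranks` input (the 1-based rank of the first relevant document for each query) is an assumption, not this repo's actual evaluation code:

```python
def hit_rate_at_k(ranks, k=10):
    # Fraction of queries whose relevant document appears in the top k
    return sum(r <= k for r in ranks) / len(ranks)

def mrr_at_k(ranks, k=10):
    # Mean of 1/rank, counting 0 for queries that miss the top k
    return sum(1.0 / r if r <= k else 0.0 for r in ranks) / len(ranks)
```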

## Usage

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer

class LinearAdapter(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return self.linear(x)

# Load the base model
base_model = SentenceTransformer('all-MiniLM-L6-v2')

# Load the adapter weights (map_location keeps this working on CPU-only machines)
adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())
adapter.load_state_dict(
    torch.load('adapters/linear_adapter_30epochs.pth', map_location='cpu'))

# Example function for encoding a query through the adapter
def encode_query(query, base_model, adapter):
    device = next(adapter.parameters()).device
    query_emb = base_model.encode(query, convert_to_tensor=True).to(device)
    adapted_query_emb = adapter(query_emb)
    return adapted_query_emb.cpu().detach().numpy()

emb = encode_query("Hello", base_model, adapter)
print(emb[:5])
```

Output:

```
[-0.13122843  0.02912715  0.07466945  0.09387457  0.13010463]
```
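Continuing from the snippet above, a hypothetical end-to-end retrieval example (the document texts are made up for illustration). Because the adapter is query-only, documents are embedded with the unmodified base model and only the query passes through the adapter:

```python
import numpy as np

# Documents use plain base-model embeddings (no adapter)
docs = [
    "Apple reduced its overall emissions year over year.",
    "The report also covers water use and waste diversion.",
]
doc_embs = base_model.encode(docs)

# Queries go through the adapter
query_emb = encode_query("How did emissions change?", base_model, adapter)

# Dot-product similarity; highest score is the top retrieval
scores = doc_embs @ query_emb
print(docs[int(np.argmax(scores))])
```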