---
license: mit
datasets:
- AdamLucek/apple-environmental-report-QA-retrieval
base_model: sentence-transformers/all-MiniLM-L6-v2
pipeline_tag: feature-extraction
library_name: peft
---

# all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA

Query-only linear adapter for [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), trained on the [AdamLucek/apple-environmental-report-QA-retrieval](https://huggingface.co/datasets/AdamLucek/apple-environmental-report-QA-retrieval) dataset. The adapter is applied to query embeddings only; document embeddings come from the base model unchanged.

Adapters were trained for 10, 20, 30, and 40 epochs with:

- Triplet margin loss (margin = 1.0, Euclidean distance, p = 2)
- AdamW optimizer
- Random negative sampling from irrelevant documents
- Learning rate: 0.003
- Batch size: 32
- Max gradient norm: 1.0
- Warmup steps: 100
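
The settings above can be wired together roughly as follows. This is a minimal sketch, not the training script from the repository: the random tensors stand in for precomputed base-model embeddings, the tensor names and sizes are hypothetical, and the learning rate is assumed to stay constant after a linear warmup.

```python
import torch
from torch import nn

torch.manual_seed(0)

class LinearAdapter(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return self.linear(x)

dim = 384  # embedding size of all-MiniLM-L6-v2

# Assumption: placeholder data standing in for precomputed base-model embeddings.
queries = torch.randn(256, dim)    # one query per training pair
positives = torch.randn(256, dim)  # the relevant document for each query
corpus = torch.randn(512, dim)     # pool of irrelevant documents for negatives

adapter = LinearAdapter(dim)
loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = torch.optim.AdamW(adapter.parameters(), lr=0.003)

# Assumption: linear warmup over 100 steps, then a constant learning rate.
warmup_steps = 100
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
)

batch_size = 32
losses = []
for epoch in range(2):  # shortened for illustration; the card uses 10 to 40 epochs
    perm = torch.randperm(len(queries))
    for start in range(0, len(queries), batch_size):
        idx = perm[start:start + batch_size]
        # Random negative sampling from the pool of irrelevant documents.
        neg_idx = torch.randint(0, len(corpus), (len(idx),))
        # Query-only: only the anchor (query) passes through the adapter.
        anchor = adapter(queries[idx])
        loss = loss_fn(anchor, positives[idx], corpus[neg_idx])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
```

Because only the anchor goes through the adapter, the positive and negative document embeddings stay in the base model's space, so an existing document index would not need to be re-embedded.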

The training script and model creation code are available in the [GitHub repo](https://github.com/ALucek/linear-adapter-embedding).

# Assessment

**Baseline Hit Rate @10**: 61.860%

**Baseline Mean Reciprocal Rank @10**: 0.31108 (Average Rank 3.2)

*Best performing checkpoint at 30 epochs*

**Average Hit Rate @10**: 66.628%

**Mean Reciprocal Rank @10**: 0.33119 (Average Rank 3.0)

A **7.7% relative improvement** in hit rate and a **6.5% relative improvement** in mean reciprocal rank over the base embedding model.

<img src="https://cdn-uploads.huggingface.co/production/uploads/65ba68a15d2ef0a4b2c892b4/ZsbVzv81cn2XW24eqbicU.png" width=800>
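
For readers who want to reproduce these kinds of numbers, the following sketch shows one common way to compute hit rate and mean reciprocal rank at a cutoff, plus the relative-improvement arithmetic. It assumes a single relevant document per query and that queries with no hit in the top k contribute 0 to the MRR; the function names are hypothetical and the evaluation code in the repository may differ.

```python
def hit_rate_and_mrr(ranked_ids_per_query, relevant_id_per_query, k=10):
    """Return (hit rate @k, mean reciprocal rank @k) over all queries."""
    hits, rr_sum = 0, 0.0
    for ranked, relevant in zip(ranked_ids_per_query, relevant_id_per_query):
        top_k = list(ranked[:k])
        if relevant in top_k:
            hits += 1
            rr_sum += 1.0 / (top_k.index(relevant) + 1)  # ranks are 1-based
    n = len(relevant_id_per_query)
    return hits / n, rr_sum / n

def relative_improvement(new, old):
    """Percentage change of `new` relative to `old`."""
    return (new - old) / old * 100

# Toy example: the third query misses entirely.
ranked = [["a", "b", "c"], ["b", "a", "c"], ["c", "d", "e"]]
relevant = ["a", "a", "z"]
hr, mrr = hit_rate_and_mrr(ranked, relevant, k=10)

# Arithmetic check against the figures reported in this card.
hit_gain = round(relative_improvement(66.628, 61.860), 1)   # 7.7
mrr_gain = round(relative_improvement(0.33119, 0.31108), 1)  # 6.5
```

The "Average Rank" values reported above appear to match the reciprocal of the MRR (1 / 0.31108 ≈ 3.2 and 1 / 0.33119 ≈ 3.0).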

# Usage

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer

class LinearAdapter(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return self.linear(x)

# Load the base model
base_model = SentenceTransformer('all-MiniLM-L6-v2')

# Load the adapter (30-epoch checkpoint)
adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())
adapter.load_state_dict(torch.load('adapters/linear_adapter_30epochs.pth'))

# Example function for encoding: embed with the base model, then apply the adapter
def encode_query(query, base_model, adapter):
    device = next(adapter.parameters()).device
    query_emb = base_model.encode(query, convert_to_tensor=True).to(device)
    adapted_query_emb = adapter(query_emb)
    return adapted_query_emb.cpu().detach().numpy()

emb = encode_query("Hello", base_model, adapter)

print(emb[:5])
```

**Output**

```
[-0.13122843 0.02912715 0.07466945 0.09387457 0.13010463]
```
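
Since the adapter is query-only, a retrieval step would compare the adapted query embedding against document embeddings produced by the unmodified base model. The sketch below illustrates that ranking step with tiny hand-made vectors. The helper `top_k_documents` is hypothetical, and ranking by Euclidean distance is an assumption chosen to mirror the training loss; the repository may use a different similarity measure.

```python
import torch

def top_k_documents(query_emb, doc_embs, k=10):
    """Return indices of the k documents closest to the query (Euclidean distance)."""
    dists = torch.cdist(query_emb.unsqueeze(0), doc_embs, p=2).squeeze(0)
    k = min(k, doc_embs.shape[0])
    return torch.topk(dists, k, largest=False).indices.tolist()

# Toy 2-D vectors standing in for base-model document embeddings.
doc_embs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
# Toy vector standing in for an adapted query embedding.
query_emb = torch.tensor([0.9, 0.1])

ranking = top_k_documents(query_emb, doc_embs, k=10)
print(ranking)  # [0, 1, 2]
```

With real data, `query_emb` would be the output of `encode_query` converted back to a tensor (for example with `torch.from_numpy`), and `doc_embs` would come from `base_model.encode(documents, convert_to_tensor=True)` with no adapter applied.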