---
language:
- en
library_name: transformers
tags:
- cross-encoder
- search
- product-search
base_model: cross-encoder/ms-marco-MiniLM-L-12-v2
model-index:
- name: esci-ms-marco-MiniLM-L-12-v2
results:
- task:
type: text-classification
metrics:
- type: mrr@10
value: 91.81
- type: ndcg@10
value: 85.46
---
# Model Description
A cross-encoder fine-tuned on the Amazon ESCI dataset for product search relevance ranking.
# Usage
## Transformers
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch import no_grad
model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown",
]
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals: wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}',
]
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
model.eval()
with no_grad():
    scores = model(**inputs).logits.cpu().detach().numpy()
print(scores)
```
## Sentence Transformers
```python
from sentence_transformers import CrossEncoder
model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"
queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown",
]
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals: wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}',
]
model = CrossEncoder(model_name, max_length=512)
scores = model.predict([(q, d) for q, d in zip(queries, documents)])
print(scores)
```
Example output:
```
[ 1.057739   1.6751697  1.039221   1.5969192 -0.8867093  0.5035825 ]
```
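The scores are unnormalized logits: higher means more relevant. To rank a query's candidate documents, sort them by score. A minimal sketch (the scores below are illustrative stand-ins for what `model.predict` returns, not live model output):

```python
# Rank candidate documents for one query by cross-encoder score.
candidates = [
    "Adidas Ultraboost",
    "Fresh watermelon",
    "Nike Air Max",
]
# Illustrative stand-ins for model.predict output (higher = more relevant).
scores = [1.68, -0.89, 1.06]

ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
for doc, score in ranked:
    print(f"{score:+.2f}  {doc}")
```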
## Training
Trained with `CrossEntropyLoss` on `<query, document>` pairs, using the relevance `grade` as the label.
```python
from sentence_transformers import InputExample
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
```
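ESCI judgments come as four classes (Exact, Substitute, Complement, Irrelevant), so they must be mapped to numeric grades before building training pairs. The mapping below is an illustrative choice, not necessarily the exact recipe used for this model:

```python
# Map ESCI judgment classes to graded relevance labels.
# NOTE: these gain values are an illustrative assumption; adjust to your scheme.
ESCI_GRADE = {
    "exact": 1.0,        # E: product matches the query
    "substitute": 0.1,   # S: a reasonable substitute
    "complement": 0.01,  # C: complements the queried product
    "irrelevant": 0.0,   # I: unrelated to the query
}

def to_sample(query: str, document: str, esci_label: str) -> dict:
    """Build a <query, document, grade> triple, ready to wrap in InputExample."""
    return {"texts": [query, document], "label": ESCI_GRADE[esci_label]}

sample = to_sample("adidas shoes", '{"title": "Adidas Ultraboost"}', "exact")
```

Each resulting dict can be passed to `InputExample(**sample)` to produce training samples like those above.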