refactor-task (#18)
Browse files- refactor: rename task type to task (fb3bde88dfd9f5d35368582c3840fd30439cbbf7)
- README.md +8 -8
- custom_st.py +6 -6
- modules.json +1 -1
README.md
CHANGED
@@ -21546,7 +21546,7 @@ Additionally, it features 5 [LoRA](https://arxiv.org/abs/2106.09685) adapters to
|
|
21546 |
|
21547 |
### Key Features:
|
21548 |
- **Extended Sequence Length:** Supports up to 8192 tokens with RoPE.
|
21549 |
-
- **Task-Specific Embedding:** Customize embeddings through the `task_type` argument with the following options:
|
21550 |
- `retrieval.query`: Used for query embeddings in asymmetric retrieval tasks
|
21551 |
- `retrieval.passage`: Used for passage embeddings in asymmetric retrieval tasks
|
21552 |
- `separation`: Used for embeddings in clustering and re-ranking applications
|
@@ -21605,7 +21605,7 @@ model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code
|
|
21605 |
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
|
21606 |
|
21607 |
with torch.no_grad():
|
21608 |
-
model_output = model(**encoded_input,
|
21609 |
|
21610 |
embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
21611 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
@@ -21643,10 +21643,10 @@ texts = [
|
|
21643 |
"Folge dem weißen Kaninchen.", # German
|
21644 |
]
|
21645 |
|
21646 |
-
# When calling the `encode` function, you can choose a `task_type` based on the use case:
|
21647 |
# 'retrieval.query', 'retrieval.passage', 'separation', 'classification', 'text-matching'
|
21648 |
-
# Alternatively, you can choose not to pass a `task_type`, and no specific LoRA adapter will be used.
|
21649 |
-
embeddings = model.encode(texts,
|
21650 |
|
21651 |
# Compute similarities
|
21652 |
print(embeddings[0] @ embeddings[1].T)
|
@@ -21680,11 +21680,11 @@ from sentence_transformers import SentenceTransformer
|
|
21680 |
|
21681 |
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
|
21682 |
|
21683 |
-
|
21684 |
embeddings = model.encode(
|
21685 |
["What is the weather like in Berlin today?"],
|
21686 |
-
|
21687 |
-
prompt_name=
|
21688 |
)
|
21689 |
```
|
21690 |
|
|
|
21546 |
|
21547 |
### Key Features:
|
21548 |
- **Extended Sequence Length:** Supports up to 8192 tokens with RoPE.
|
21549 |
+
- **Task-Specific Embedding:** Customize embeddings through the `task` argument with the following options:
|
21550 |
- `retrieval.query`: Used for query embeddings in asymmetric retrieval tasks
|
21551 |
- `retrieval.passage`: Used for passage embeddings in asymmetric retrieval tasks
|
21552 |
- `separation`: Used for embeddings in clustering and re-ranking applications
|
|
|
21605 |
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
|
21606 |
|
21607 |
with torch.no_grad():
|
21608 |
+
model_output = model(**encoded_input, task='retrieval.query')
|
21609 |
|
21610 |
embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
21611 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
|
21643 |
"Folge dem weißen Kaninchen.", # German
|
21644 |
]
|
21645 |
|
21646 |
+
# When calling the `encode` function, you can choose a `task` based on the use case:
|
21647 |
# 'retrieval.query', 'retrieval.passage', 'separation', 'classification', 'text-matching'
|
21648 |
+
# Alternatively, you can choose not to pass a `task`, and no specific LoRA adapter will be used.
|
21649 |
+
embeddings = model.encode(texts, task="text-matching")
|
21650 |
|
21651 |
# Compute similarities
|
21652 |
print(embeddings[0] @ embeddings[1].T)
|
|
|
21680 |
|
21681 |
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
|
21682 |
|
21683 |
+
task = "retrieval.query"
|
21684 |
embeddings = model.encode(
|
21685 |
["What is the weather like in Berlin today?"],
|
21686 |
+
task=task,
|
21687 |
+
prompt_name=task,
|
21688 |
)
|
21689 |
```
|
21690 |
|
custom_st.py
CHANGED
@@ -91,19 +91,19 @@ class Transformer(nn.Module):
|
|
91 |
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
|
92 |
|
93 |
def forward(
|
94 |
-
self, features: Dict[str, torch.Tensor],
|
95 |
) -> Dict[str, torch.Tensor]:
|
96 |
"""Returns token_embeddings, cls_token"""
|
97 |
-
if
|
98 |
raise ValueError(
|
99 |
-
f"Unsupported task '{
|
100 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
101 |
-
f"Alternatively, don't pass the `
|
102 |
)
|
103 |
|
104 |
adapter_mask = None
|
105 |
-
if
|
106 |
-
task_id = self._adaptation_map[
|
107 |
num_examples = features['input_ids'].size(0)
|
108 |
adapter_mask = torch.full(
|
109 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
|
|
91 |
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
|
92 |
|
93 |
def forward(
|
94 |
+
self, features: Dict[str, torch.Tensor], task: Optional[str] = None
|
95 |
) -> Dict[str, torch.Tensor]:
|
96 |
"""Returns token_embeddings, cls_token"""
|
97 |
+
if task and task not in self._lora_adaptations:
|
98 |
raise ValueError(
|
99 |
+
f"Unsupported task '{task}'. "
|
100 |
f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
|
101 |
+
f"Alternatively, don't pass the `task` argument to disable LoRA."
|
102 |
)
|
103 |
|
104 |
adapter_mask = None
|
105 |
+
if task:
|
106 |
+
task_id = self._adaptation_map[task]
|
107 |
num_examples = features['input_ids'].size(0)
|
108 |
adapter_mask = torch.full(
|
109 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
modules.json
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
"name": "0",
|
5 |
"path": "",
|
6 |
"type": "custom_st.Transformer",
|
7 |
-
"kwargs": ["
|
8 |
},
|
9 |
{
|
10 |
"idx": 1,
|
|
|
4 |
"name": "0",
|
5 |
"path": "",
|
6 |
"type": "custom_st.Transformer",
|
7 |
+
"kwargs": ["task"]
|
8 |
},
|
9 |
{
|
10 |
"idx": 1,
|