Update README.md
Browse files
README.md
CHANGED
@@ -20,18 +20,23 @@ language:
|
|
20 |
```python
|
21 |
from transformers import AutoTokenizer, AutoModel
|
22 |
import torch
|
|
|
23 |
|
24 |
# Sentences we want sentence embeddings for
|
25 |
sentences = ['This is an example sentence', 'Each sentence is converted']
|
26 |
|
27 |
# Load model from HuggingFace Hub
|
28 |
tokenizer = AutoTokenizer.from_pretrained('{MODEL_NAME}')
|
29 |
-
model = AutoModel.from_pretrained('{MODEL_NAME}')
|
30 |
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
# I used the mean-pooling method for sentence representation
|
34 |
with torch.no_grad():
|
|
|
35 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
36 |
representations, _ = self.model(**inputs, return_dict=False)
|
37 |
attention_mask = inputs["attention_mask"]
|
@@ -39,7 +44,9 @@ with torch.no_grad():
|
|
39 |
summed = torch.sum(representations * input_mask_expanded, 1)
|
40 |
sum_mask = input_mask_expanded.sum(1)
|
41 |
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
42 |
-
|
|
|
|
|
43 |
|
44 |
```
|
45 |
|
|
|
20 |
```python
|
21 |
from transformers import AutoTokenizer, AutoModel
|
22 |
import torch
|
23 |
+
device = torch.device('cuda')
|
24 |
|
25 |
# Sentences we want sentence embeddings for
|
26 |
sentences = ['This is an example sentence', 'Each sentence is converted']
|
27 |
|
28 |
# Load model from HuggingFace Hub
|
29 |
tokenizer = AutoTokenizer.from_pretrained('{MODEL_NAME}')
|
30 |
+
model = AutoModel.from_pretrained('{MODEL_NAME}').to(device)
|
31 |
|
32 |
+
tokenized_data = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
33 |
+
dataloader = DataLoader(tokenized_data, batch_size=batch_size, pin_memory=True)
|
34 |
+
all_outputs = torch.zeros((len(tokenized_data), self.hidden_size)).to(device)
|
35 |
+
start_idx = 0
|
36 |
|
37 |
# I used the mean-pooling method for sentence representation
|
38 |
with torch.no_grad():
|
39 |
+
for inputs in tqdm(dataloader):
|
40 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
41 |
representations, _ = self.model(**inputs, return_dict=False)
|
42 |
attention_mask = inputs["attention_mask"]
|
|
|
44 |
summed = torch.sum(representations * input_mask_expanded, 1)
|
45 |
sum_mask = input_mask_expanded.sum(1)
|
46 |
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
47 |
+
end_idx = start_idx + representations.shape[0]
|
48 |
+
all_outputs[start_idx:end_idx] = (summed / sum_mask)
|
49 |
+
start_idx = end_idx
|
50 |
|
51 |
```
|
52 |
|