Update README.md
Browse files
README.md
CHANGED
@@ -41,7 +41,7 @@ model_output = model(**encoded_input)
|
|
41 |
Sequence embeddings can be produced as follows:
|
42 |
|
43 |
```python
|
44 |
-
def
|
45 |
mask = encoded_input['attention_mask'].float()
|
46 |
d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # dict of sep tokens
|
47 |
# make sep token invisible
|
@@ -53,7 +53,7 @@ def sequence_embeddings(encoded_input, model_output):
|
|
53 |
sum_mask = torch.clamp(mask.sum(1), min=1e-9)
|
54 |
return sum_embeddings / sum_mask
|
55 |
|
56 |
-
seq_embeds =
|
57 |
```
|
58 |
|
59 |
### Fine-tune
|
|
|
41 |
Sequence embeddings can be produced as follows:
|
42 |
|
43 |
```python
|
44 |
+
def get_sequence_embeddings(encoded_input, model_output):
|
45 |
mask = encoded_input['attention_mask'].float()
|
46 |
d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()} # dict of sep tokens
|
47 |
# make sep token invisible
|
|
|
53 |
sum_mask = torch.clamp(mask.sum(1), min=1e-9)
|
54 |
return sum_embeddings / sum_mask
|
55 |
|
56 |
+
seq_embeds = get_sequence_embeddings(encoded_input, model_output)
|
57 |
```
|
58 |
|
59 |
### Fine-tune
|