Upload utils.py
Browse files- eval/utils.py +56 -0
eval/utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from torch import Tensor
|
5 |
+
from typing import Mapping
|
6 |
+
|
7 |
+
|
8 |
+
def _setup_logger():
|
9 |
+
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
|
10 |
+
logger = logging.getLogger()
|
11 |
+
logger.setLevel(logging.INFO)
|
12 |
+
|
13 |
+
console_handler = logging.StreamHandler()
|
14 |
+
console_handler.setFormatter(log_format)
|
15 |
+
logger.handlers = [console_handler]
|
16 |
+
|
17 |
+
return logger
|
18 |
+
|
19 |
+
|
20 |
+
# Module-level logger: the root logger returned by _setup_logger().
# The call runs at import time, so importing this module reconfigures
# the root logger's level and handlers.
logger = _setup_logger()
|
21 |
+
|
22 |
+
|
23 |
+
def move_to_cuda(sample):
    """Recursively move every tensor found in ``sample`` to CUDA.

    Tensors are moved with ``Tensor.cuda(non_blocking=True)``. Containers are
    rebuilt around the moved values; anything that is neither a tensor nor a
    supported container is returned unchanged.

    Args:
        sample: A tensor, or a (possibly nested) dict, list, tuple or other
            ``Mapping`` whose leaves may be tensors. Unsized values such as
            scalars, ``None`` or 0-dim tensors are also accepted.

    Returns:
        ``{}`` when ``len(sample) == 0`` (whatever the input type), otherwise
        a structure mirroring ``sample``: dicts (including dict subclasses)
        come back as plain ``dict``, lists as ``list``, tuples as plain
        ``tuple``, and other mappings are rebuilt with
        ``type(mapping)(dict_of_items)``.
    """
    # Only sized inputs can be empty. len() raises TypeError for unsized
    # inputs (scalars, None, 0-dim tensors); the helper below handles those,
    # so treat them as non-empty instead of crashing here.
    try:
        is_empty = len(sample) == 0
    except TypeError:
        is_empty = False
    if is_empty:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        if isinstance(maybe_tensor, Mapping):
            # Non-dict mappings keep their own type; the constructor is
            # expected to accept a plain dict of the converted items.
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        # Leaves that are not tensors (ints, strings, None, ...) pass through.
        return maybe_tensor

    return _move_to_cuda(sample)
|
42 |
+
|
43 |
+
|
44 |
+
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Reduce per-token hidden states to one embedding per sequence.

    Positions whose mask entry is zero are zeroed out before pooling.

    Args:
        last_hidden_states: Token-level hidden states. The indexing below
            (sum over dim 1, ``[:, 0]``) assumes a (batch, seq_len, hidden)
            layout -- TODO confirm against callers.
        attention_mask: Mask with one entry per token; non-zero entries mark
            tokens to keep. Assumed to be (batch, seq_len).
        pool_type: ``"avg"`` for the mean over kept tokens, or ``"cls"`` for
            the (masked) hidden state at position 0.

    Returns:
        The pooled embeddings, one row per sequence.

    Raises:
        ValueError: If ``pool_type`` is neither ``"avg"`` nor ``"cls"``.
    """
    keep = attention_mask[..., None].bool()
    last_hidden = last_hidden_states.masked_fill(~keep, 0.0)

    if pool_type == "avg":
        token_counts = attention_mask.sum(dim=1)
        # A fully masked row has a count of 0, which would give 0/0 = NaN.
        # Use 1 instead so such rows pool to a zero vector; rows with a
        # non-zero count are unaffected.
        token_counts = token_counts.masked_fill(token_counts == 0, 1)
        emb = last_hidden.sum(dim=1) / token_counts[..., None]
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")

    return emb
|