Merge remote-tracking branch 'thu/main'
modeling_chatglm.py  +8 -5
@@ -5,6 +5,7 @@ import copy
 import os
 import warnings
 import re
+import sys
 
 import torch
 import torch.utils.checkpoint
@@ -32,10 +33,12 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from .configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
+
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)
 
 logger = logging.get_logger(__name__)
 
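The hunk above does not delete the JIT switches; together with the new import sys, it wraps them in a platform check so they are skipped on macOS (sys.platform == 'darwin'). The torch._C._jit_* calls are private PyTorch APIs, and the diff itself records no rationale for excluding Darwin. The following is only a minimal sketch of the resulting module-level behaviour; the enable_jit_fusion helper is an illustrative name, not something in the file:

import sys

import torch


def enable_jit_fusion() -> None:
    # Private torch._C switches: turn off the profiling executor and
    # let the legacy fuser fuse kernels on both CPU and GPU.
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)


# As in the merged change: skip the switches entirely on macOS.
if sys.platform != 'darwin':
    enable_jit_fusion()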
@@ -267,7 +270,7 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
-        dtype = attention_scores.type()
+        dtype = attention_scores.dtype
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff
 
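The last hunk sits in the half-precision softmax path of attention_fn: masked positions get a -10000.0 fill (large enough to vanish under softmax, yet still representable in float16), the original dtype is saved, and the scores are upcast to float32 so the softmax is numerically stable; presumably the probabilities are cast back to the saved dtype further down in unchanged code. A self-contained sketch of that pattern, assuming illustrative names (masked_softmax_fp32, scores, mask, scaling_coeff) rather than the file's real signature:

import torch
import torch.nn.functional as F


def masked_softmax_fp32(scores: torch.Tensor, mask: torch.Tensor,
                        scaling_coeff: float) -> torch.Tensor:
    # mask is a bool tensor; True marks positions to suppress.
    if not (mask == 0).all():
        # -10000.0 acts as "minus infinity" for softmax while still
        # fitting in float16 (whose minimum is about -65504).
        scores.masked_fill_(mask, -10000.0)
    dtype = scores.dtype                      # remember fp16 before the upcast
    scores = scores.float() * scaling_coeff   # do the softmax in fp32
    probs = F.softmax(scores, dim=-1)
    return probs.type(dtype)                  # cast back to the original dtype


# Example: a causal mask over 4 positions with half-precision scores.
scores = torch.randn(1, 2, 4, 4, dtype=torch.float16)
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
probs = masked_softmax_fp32(scores, mask, scaling_coeff=1.0)
assert probs.dtype == torch.float16

Saving dtype before the .float() call is the point of the changed line: read after the upcast it would always be torch.float32, and the cast at the end would silently leave the probabilities in full precision.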