silver committed on
Commit
fb1c9f0
2 Parent(s): a23d6a1 4a9b711

Merge remote-tracking branch 'thu/main'

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +8 -5
modeling_chatglm.py CHANGED
@@ -5,6 +5,7 @@ import copy
5
  import os
6
  import warnings
7
  import re
 
8
 
9
  import torch
10
  import torch.utils.checkpoint
@@ -32,10 +33,12 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
32
  from .configuration_chatglm import ChatGLMConfig
33
 
34
  # flags required to enable jit fusion kernels
35
- torch._C._jit_set_profiling_mode(False)
36
- torch._C._jit_set_profiling_executor(False)
37
- torch._C._jit_override_can_fuse_on_cpu(True)
38
- torch._C._jit_override_can_fuse_on_gpu(True)
 
 
39
 
40
  logger = logging.get_logger(__name__)
41
 
@@ -267,7 +270,7 @@ def attention_fn(
267
  if not (attention_mask == 0).all():
268
  # if auto-regressive, skip
269
  attention_scores.masked_fill_(attention_mask, -10000.0)
270
- dtype = attention_scores.type()
271
  attention_scores = attention_scores.float()
272
  attention_scores = attention_scores * query_key_layer_scaling_coeff
273
 
 
5
  import os
6
  import warnings
7
  import re
8
+ import sys
9
 
10
  import torch
11
  import torch.utils.checkpoint
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
35
  # flags required to enable jit fusion kernels
36
+
37
+ if sys.platform != 'darwin':
38
+ torch._C._jit_set_profiling_mode(False)
39
+ torch._C._jit_set_profiling_executor(False)
40
+ torch._C._jit_override_can_fuse_on_cpu(True)
41
+ torch._C._jit_override_can_fuse_on_gpu(True)
42
 
43
  logger = logging.get_logger(__name__)
44
 
 
270
  if not (attention_mask == 0).all():
271
  # if auto-regressive, skip
272
  attention_scores.masked_fill_(attention_mask, -10000.0)
273
+ dtype = attention_scores.dtype
274
  attention_scores = attention_scores.float()
275
  attention_scores = attention_scores * query_key_layer_scaling_coeff
276