fix pkv update for new transformers compatibility
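
Background for the change: this model overrides `_update_model_kwargs_for_generation` and refreshes `past_key_values` (pkv) by calling the private `GenerationMixin._extract_past_from_model_output` helper. In newer transformers releases that helper no longer returns the cache alone, so the diff below gates on the installed version and takes element `[1]` of the new return value. A minimal, illustration-only sketch of a shape-agnostic way to do the same update; the normalization logic is an assumption based on this commit's `[1]` indexing, not code from the repository:

from typing import Any, Dict

def _update_past(self, outputs, model_kwargs: Dict[str, Any],
                 standardize_cache_format: bool = False) -> Dict[str, Any]:
    # Illustration only (not part of this commit): accept either return shape
    # of the private GenerationMixin._extract_past_from_model_output helper.
    extracted = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )
    # Newer transformers are assumed to return a (cache_name, cache) pair;
    # older releases return the cache object directly.
    if isinstance(extracted, tuple) and len(extracted) == 2 and isinstance(extracted[0], str):
        extracted = extracted[1]
    model_kwargs["past_key_values"] = extracted
    return model_kwargs

The commit itself instead keeps the two code paths explicit behind a module-level version flag, which is easier to audit against a pinned transformers version.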
modeling_chatglm.py (+13 -4)
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+import transformers
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -45,6 +46,8 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
+is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
+
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
@@ -867,10 +870,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
-        # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if is_transformers_4_42_or_higher:
+            # update past_key_values
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )[1]
+        else:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )
+
 
         # update attention mask
         if "attention_mask" in model_kwargs:
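
One caveat on the version probe itself: `int(transformers.__version__.split(".")[1]) >= 42` only looks at the minor component, so it assumes the major version stays at 4 and that the second field parses as a plain integer. A sketch of a more defensive check using `packaging.version` (packaging ships as a transformers dependency; the threshold constant name is introduced here for illustration):

import transformers
from packaging import version

# Assumed cutover point: the commit gates on minor version 42, so 4.42.0 is
# used as the threshold here. Parsing the full version string also tolerates
# dev/rc suffixes and a future major-version bump.
_NEW_PKV_API_MIN = version.parse("4.42.0")  # hypothetical name, not from the commit
is_transformers_4_42_or_higher = version.parse(transformers.__version__) >= _NEW_PKV_API_MIN

Either probe feeds the same module-level flag, so the branch inside `_update_model_kwargs_for_generation` stays exactly as written in the diff.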