Add support for float32
modeling_chatglm.py  (+2 -2)
@@ -254,13 +254,13 @@ def attention_fn(
         if not (attention_mask == 0).all():
             # if auto-regressive, skip
             attention_scores.masked_fill_(attention_mask, -10000.0)
-
+        dtype = attention_scores.type()
         attention_scores = attention_scores.float()
         attention_scores = attention_scores * query_key_layer_scaling_coeff
 
         attention_probs = F.softmax(attention_scores, dim=-1)
 
-        attention_probs = attention_probs.half()
+        attention_probs = attention_probs.type(dtype)
 
         # =========================
         # Context layer. [sq, b, hp]
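The idea behind the change: the softmax over the attention scores is computed in float32 for numerical stability, and the result must then be cast back to whatever dtype the model actually runs in. The previous code hard-coded the cast back to half precision, which broke float32 inference; saving the dtype before the upcast and restoring it afterwards works for both. A minimal sketch of that round-trip, assuming a standalone helper (softmax_in_fp32, its parameters, and the test tensors are illustrative, not the repo's attention_fn signature):

import torch
import torch.nn.functional as F

def softmax_in_fp32(attention_scores: torch.Tensor, scaling_coeff: float = 1.0) -> torch.Tensor:
    # Remember the incoming dtype, e.g. 'torch.HalfTensor' or 'torch.FloatTensor'.
    dtype = attention_scores.type()
    # Upcast to float32 so the scaling and softmax are numerically stable.
    attention_scores = attention_scores.float() * scaling_coeff
    attention_probs = F.softmax(attention_scores, dim=-1)
    # Cast back to the original dtype instead of hard-coding .half().
    return attention_probs.type(dtype)

# With float16 inputs the probabilities come back as float16; float32 inputs stay float32.
scores_fp16 = torch.randn(2, 4, 8, 8, dtype=torch.float16)
scores_fp32 = torch.randn(2, 4, 8, 8, dtype=torch.float32)
assert softmax_in_fp32(scores_fp16).dtype == torch.float16
assert softmax_in_fp32(scores_fp32).dtype == torch.float32

Note that Tensor.type() with no argument returns the dtype as a string, which Tensor.type(dtype) accepts back, so the round-trip preserves the caller's precision without branching on float16 vs. float32.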