sync with the latest official code
modeling_qwen.py  CHANGED: +5 -7
@@ -520,11 +520,9 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask.masked_fill_(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
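This first hunk is the substantive fix: the mask expansion is now sized from query, which is always present, instead of causal_mask, which can be None, and the write into the expanded mask becomes out-of-place. (The removed masked_fill line is truncated in this view; reading it as the in-place masked_fill_ variant is an inference from the replacement line.) Below is a standalone sketch of the pattern, with illustrative shapes (batch=1, heads=1, q_len=k_len=4, head_dim=8) that are not taken from the model:

import torch
import torch.nn.functional as F

query = torch.randn(1, 1, 4, 8)
key = torch.randn(1, 1, 4, 8)
value = torch.randn(1, 1, 4, 8)

# Padding mask with a singleton query dimension, broadcast over queries.
attention_mask = torch.zeros(1, 1, 1, 4)
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)

causal_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()

# expand() returns a broadcast view in which many elements alias a single
# memory location, so an in-place attention_mask.masked_fill_(...) raises a
# RuntimeError; the out-of-place call below allocates a fresh tensor instead.
attention_mask = attention_mask.masked_fill(
    ~causal_mask, torch.finfo(query.dtype).min
)

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask
)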
@@ -1330,14 +1328,14 @@ def apply_rotary_pos_emb(t, freqs):
         t (tensor(batch_size, seq_len, n_head, head_dim)):
             the input embedding/hidden states
         freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
-            the cached cos/sin position embeddings
+            the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
     if apply_rotary_emb_func is not None and t.is_cuda:
-        # apply_rotary_emb in flash_attn requires cos/sin to be of
-        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+        # apply_rotary_emb in flash_attn requires cos/sin to be of
+        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
         cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]