update batch inference
Browse files- modeling_qwen.py +32 -20
modeling_qwen.py
CHANGED
@@ -35,6 +35,8 @@ from torch import nn
|
|
35 |
SUPPORT_CUDA = torch.cuda.is_available()
|
36 |
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
|
37 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
|
|
|
|
38 |
|
39 |
from .configuration_qwen import QWenConfig
|
40 |
from .qwen_generation_utils import (
|
@@ -186,7 +188,7 @@ class FlashSelfAttention(torch.nn.Module):
|
|
186 |
device=q.device,
|
187 |
)
|
188 |
|
189 |
-
if attention_mask is not None:
|
190 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
191 |
if q.size(0) == v.size(0):
|
192 |
q = q[indices_k]
|
@@ -222,7 +224,7 @@ class FlashSelfAttention(torch.nn.Module):
|
|
222 |
softmax_scale=self.softmax_scale,
|
223 |
causal=is_causal,
|
224 |
)
|
225 |
-
if attention_mask is not None and seqlen_q == seqlen_k:
|
226 |
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
|
227 |
else:
|
228 |
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
|
@@ -451,7 +453,7 @@ class QWenAttention(nn.Module):
|
|
451 |
def forward(
|
452 |
self,
|
453 |
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
454 |
-
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
455 |
registered_causal_mask: Optional[torch.Tensor] = None,
|
456 |
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
457 |
attention_mask: Optional[torch.FloatTensor] = None,
|
@@ -543,11 +545,7 @@ class QWenAttention(nn.Module):
|
|
543 |
and query.is_cuda
|
544 |
):
|
545 |
q, k, v = query, key, value
|
546 |
-
|
547 |
-
|
548 |
-
# b s h d -> b s (h d)
|
549 |
-
context_layer = context_layer.flatten(2,3).contiguous()
|
550 |
-
|
551 |
else:
|
552 |
query = query.permute(0, 2, 1, 3)
|
553 |
if not self.use_cache_quantization:
|
@@ -561,12 +559,28 @@ class QWenAttention(nn.Module):
|
|
561 |
and not query.is_cuda
|
562 |
):
|
563 |
raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
attn_output = self.c_proj(context_layer)
|
572 |
|
@@ -624,7 +638,7 @@ class QWenBlock(nn.Module):
|
|
624 |
def forward(
|
625 |
self,
|
626 |
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
627 |
-
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
628 |
registered_causal_mask: Optional[torch.Tensor] = None,
|
629 |
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
630 |
attention_mask: Optional[torch.FloatTensor] = None,
|
@@ -890,11 +904,9 @@ class QWenModel(QWenPreTrainedModel):
|
|
890 |
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
891 |
ntk_alpha_list.append(ntk_alpha)
|
892 |
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
|
897 |
-
rotary_pos_emb_list.append(rotary_pos_emb)
|
898 |
|
899 |
hidden_states = self.drop(hidden_states)
|
900 |
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
35 |
SUPPORT_CUDA = torch.cuda.is_available()
|
36 |
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
|
37 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
38 |
+
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
39 |
+
|
40 |
|
41 |
from .configuration_qwen import QWenConfig
|
42 |
from .qwen_generation_utils import (
|
|
|
188 |
device=q.device,
|
189 |
)
|
190 |
|
191 |
+
if batch_size > 1 and attention_mask is not None:
|
192 |
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
193 |
if q.size(0) == v.size(0):
|
194 |
q = q[indices_k]
|
|
|
224 |
softmax_scale=self.softmax_scale,
|
225 |
causal=is_causal,
|
226 |
)
|
227 |
+
if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
|
228 |
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
|
229 |
else:
|
230 |
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
|
|
|
453 |
def forward(
|
454 |
self,
|
455 |
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
456 |
+
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
457 |
registered_causal_mask: Optional[torch.Tensor] = None,
|
458 |
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
459 |
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
545 |
and query.is_cuda
|
546 |
):
|
547 |
q, k, v = query, key, value
|
548 |
+
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
549 |
else:
|
550 |
query = query.permute(0, 2, 1, 3)
|
551 |
if not self.use_cache_quantization:
|
|
|
559 |
and not query.is_cuda
|
560 |
):
|
561 |
raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
|
562 |
+
|
563 |
+
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
564 |
+
causal_mask = registered_causal_mask[
|
565 |
+
:, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
|
566 |
+
]
|
567 |
+
if attention_mask is not None:
|
568 |
+
attention_mask = attention_mask.expand(
|
569 |
+
-1, -1, causal_mask.size(2), -1
|
570 |
+
).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
571 |
+
else:
|
572 |
+
attention_mask = causal_mask
|
573 |
+
attn_output = F.scaled_dot_product_attention(
|
574 |
+
query, key, value, attn_mask=attention_mask
|
575 |
+
).transpose(1, 2)
|
576 |
+
attn_weight = None
|
577 |
+
else:
|
578 |
+
attn_output, attn_weight = self._attn(
|
579 |
+
query, key, value, registered_causal_mask, attention_mask, head_mask
|
580 |
+
)
|
581 |
+
context_layer = self._merge_heads(
|
582 |
+
attn_output, self.num_heads, self.head_dim
|
583 |
+
)
|
584 |
|
585 |
attn_output = self.c_proj(context_layer)
|
586 |
|
|
|
638 |
def forward(
|
639 |
self,
|
640 |
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
641 |
+
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
642 |
registered_causal_mask: Optional[torch.Tensor] = None,
|
643 |
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
644 |
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
904 |
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
905 |
ntk_alpha_list.append(ntk_alpha)
|
906 |
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
907 |
+
rotary_pos_emb_list = [
|
908 |
+
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
909 |
+
]
|
|
|
|
|
910 |
|
911 |
hidden_states = self.drop(hidden_states)
|
912 |
output_shape = input_shape + (hidden_states.size(-1),)
|