update modeling_qwen.py
modeling_qwen.py  CHANGED  (+11 -4)
@@ -31,7 +31,11 @@ try:
 except ImportError:
     rearrange = None
 from torch import nn
-
+
+try:
+    from kernels.cpp_kernels import cache_autogptq_cuda_256
+except ImportError:
+    cache_autogptq_cuda_256 = None
 
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
@@ -293,7 +297,7 @@ class QWenAttention(nn.Module):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
@@ -348,7 +352,7 @@ class QWenAttention(nn.Module):
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
@@ -1021,7 +1025,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
             config.use_flash_attn = False
         if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-
+            try:
+                from kernels.cpp_kernels import cache_autogptq_cuda_256
+            except ImportError:
+                cache_autogptq_cuda_256 = None
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
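
The change follows a guarded optional-kernel pattern: the compiled extension is imported inside try/except, bound to None when it is unavailable, and each call site checks both the config flag and the import result before taking the CUDA path. A minimal sketch of that pattern, assuming a hypothetical compiled extension named fast_kernels with a batched_matmul entry point standing in for the real cache_autogptq_cuda_256 calls:

# Sketch only: "fast_kernels" and its batched_matmul are hypothetical stand-ins
# for a compiled extension such as cache_autogptq_cuda_256.
import torch

try:
    import fast_kernels  # hypothetical compiled CUDA extension
except ImportError:
    fast_kernels = None  # extension not built/installed; fall back to pure PyTorch


def batched_matmul(a: torch.Tensor, b: torch.Tensor, use_kernel: bool = True) -> torch.Tensor:
    # Mirrors `if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:`
    # take the fast path only when it was requested AND the import succeeded.
    if use_kernel and fast_kernels is not None:
        return fast_kernels.batched_matmul(a, b)  # hypothetical kernel entry point
    return torch.matmul(a, b)  # pure-PyTorch fallback


out = batched_matmul(torch.randn(2, 4, 8), torch.randn(2, 8, 3))

Checking `is not None` at each call site rather than failing at import time keeps the model loadable on machines where the kernels were never compiled, at the cost of silently falling back to the slower path.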