Qwen
yangapku committed
Commit 21ed8f8
1 Parent(s): 5255e1c

update modeling_qwen.py

Files changed (1)
  1. modeling_qwen.py +11 -4
modeling_qwen.py CHANGED
@@ -31,7 +31,11 @@ try:
 except ImportError:
     rearrange = None
 from torch import nn
-from kernels.cpp_kernels import cache_autogptq_cuda_256
+
+try:
+    from kernels.cpp_kernels import cache_autogptq_cuda_256
+except ImportError:
+    cache_autogptq_cuda_256 = None
 
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
@@ -293,7 +297,7 @@ class QWenAttention(nn.Module):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
@@ -348,7 +352,7 @@ class QWenAttention(nn.Module):
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
@@ -1021,7 +1025,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
             config.use_flash_attn = False
         if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-            from kernels.cpp_kernels import cache_autogptq_cuda_256
+            try:
+                from kernels.cpp_kernels import cache_autogptq_cuda_256
+            except ImportError:
+                cache_autogptq_cuda_256 = None
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
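
The change follows a common optional-import pattern: the compiled cache_autogptq_cuda_256 extension is imported inside try/except so the model still loads when kernels/cpp_kernels has not been built, and each call site checks the name against None before dispatching to the fast path. Below is a minimal sketch of that pattern; optional_kernel, fast_matmul, and the plain-PyTorch fallback are hypothetical placeholders for illustration, not names from modeling_qwen.py.

import torch

# Import the compiled extension if it is available; otherwise bind the name
# to None so the rest of the module can still be imported and used.
try:
    import optional_kernel  # hypothetical compiled extension
except ImportError:
    optional_kernel = None

def matmul(a: torch.Tensor, b: torch.Tensor, use_kernel: bool = True) -> torch.Tensor:
    # Use the kernel only when it was requested AND actually imported, mirroring
    # the `use_cache_kernel and cache_autogptq_cuda_256 is not None` guard above.
    if use_kernel and optional_kernel is not None:
        return optional_kernel.fast_matmul(a, b)
    # Pure-PyTorch fallback keeps inference working without the extension.
    return a @ b

With this guard, a missing extension falls through to the non-kernel branch instead of raising an AttributeError on a None module at call time.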