danieldk HF staff commited on
Commit
20100e6
·
1 Parent(s): 132e594
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py +21 -0
  2. build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so +3 -0
  3. build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py +173 -0
  4. build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py +9 -0
  5. build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py +62 -0
  6. build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py +21 -0
  7. build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so +3 -0
  8. build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py +173 -0
  9. build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py +9 -0
  10. build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py +62 -0
  11. build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py +21 -0
  12. build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so +3 -0
  13. build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py +173 -0
  14. build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py +9 -0
  15. build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py +62 -0
  16. build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py +21 -0
  17. build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so +3 -0
  18. build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py +173 -0
  19. build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py +9 -0
  20. build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py +62 -0
  21. build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py +21 -0
  22. build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so +3 -0
  23. build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py +173 -0
  24. build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py +9 -0
  25. build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py +62 -0
  26. build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py +21 -0
  27. build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so +3 -0
  28. build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py +173 -0
  29. build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py +9 -0
  30. build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py +62 -0
  31. build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py +21 -0
  32. build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so +3 -0
  33. build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py +173 -0
  34. build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py +9 -0
  35. build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py +62 -0
  36. build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py +21 -0
  37. build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so +3 -0
  38. build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py +173 -0
  39. build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py +9 -0
  40. build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py +62 -0
  41. build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py +21 -0
  42. build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so +3 -0
  43. build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py +173 -0
  44. build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py +9 -0
  45. build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py +62 -0
  46. build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py +21 -0
  47. build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so +3 -0
  48. build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py +173 -0
  49. build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py +9 -0
  50. build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py +62 -0
build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aee255dc2618e23d4e2076ff3d16c4fbd12d63742fde84252cfb6bfe55c5376e
3
+ size 78886392
build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_6yvgebnqctora
3
+ ops = torch.ops._attention_6yvgebnqctora
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_6yvgebnqctora::{op_name}"
build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22599ebe9d209fcc82068054caf39f93e6828bb3889b344e655fee50e7a98864
3
+ size 75398808
build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_4jg2igd54wzge
3
+ ops = torch.ops._attention_4jg2igd54wzge
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_4jg2igd54wzge::{op_name}"
build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42a3b2b450b7e284694e8e6d7398627b977d1e5da12bb79d93c6009c192922f9
3
+ size 75568320
build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_syg6kbhkhc4xk
3
+ ops = torch.ops._attention_syg6kbhkhc4xk
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_syg6kbhkhc4xk::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffad04fc3e82be818bafed25c1be1e9e6145f99eb0ef89ab87ef5ab8c8366f9b
3
+ size 78850608
build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_hhzgzhvc7zviy
3
+ ops = torch.ops._attention_hhzgzhvc7zviy
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_hhzgzhvc7zviy::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ed1c9c4c080a10f7d7f8c18e8e96613020851f769a1bf5e2b92bf19b4e01fb6
3
+ size 75359216
build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_gbi5gm244waic
3
+ ops = torch.ops._attention_gbi5gm244waic
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_gbi5gm244waic::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f263e022ef503e7fffcbc15ef59e515b84889d4c473b9113f3fea292725b9e37
3
+ size 75532912
build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_ill75rmpj7yds
3
+ ops = torch.ops._attention_ill75rmpj7yds
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_ill75rmpj7yds::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e66eca8e825e5cee2dc18c1235319a4e5b1372d843cab74660e8d94792e02f7c
3
+ size 78857896
build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_6qe5ft3kiteru
3
+ ops = torch.ops._attention_6qe5ft3kiteru
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_6qe5ft3kiteru::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:913ba8f5166dc4e84ed8a2da4b1dc44c178a93eeb16aae9782176fb089a459a7
3
+ size 75552112
build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_ftq3cjdxqfw4m
3
+ ops = torch.ops._attention_ftq3cjdxqfw4m
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_ftq3cjdxqfw4m::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91380eebc7db2ff85f92e687d388055f210123bac602a6bc273172834bf49012
3
+ size 75376640
build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_lkibbjh726iwm
3
+ ops = torch.ops._attention_lkibbjh726iwm
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_lkibbjh726iwm::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import (
2
+ convert_fp8,
3
+ copy_blocks,
4
+ paged_attention_v1,
5
+ paged_attention_v2,
6
+ reshape_and_cache,
7
+ reshape_and_cache_flash,
8
+ swap_blocks,
9
+ )
10
+ from ._ops import ops
11
+
12
+ __all__ = [
13
+ "convert_fp8",
14
+ "copy_blocks",
15
+ "ops",
16
+ "paged_attention_v1",
17
+ "paged_attention_v2",
18
+ "reshape_and_cache",
19
+ "reshape_and_cache_flash",
20
+ "swap_blocks",
21
+ ]
build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3746697abeeb7f829661c0912ccb36a7f7bb16c1f9eb7f14b1ee5e52c93ec055
3
+ size 78830632
build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ # page attention ops
9
+ def paged_attention_v1(
10
+ out: torch.Tensor,
11
+ query: torch.Tensor,
12
+ key_cache: torch.Tensor,
13
+ value_cache: torch.Tensor,
14
+ num_kv_heads: int,
15
+ scale: float,
16
+ block_tables: torch.Tensor,
17
+ seq_lens: torch.Tensor,
18
+ block_size: int,
19
+ max_seq_len: int,
20
+ alibi_slopes: Optional[torch.Tensor],
21
+ kv_cache_dtype: str,
22
+ k_scale: float,
23
+ v_scale: float,
24
+ tp_rank: int = 0,
25
+ blocksparse_local_blocks: int = 0,
26
+ blocksparse_vert_stride: int = 0,
27
+ blocksparse_block_size: int = 64,
28
+ blocksparse_head_sliding_step: int = 0,
29
+ ) -> None:
30
+ ops.paged_attention_v1(
31
+ out,
32
+ query,
33
+ key_cache,
34
+ value_cache,
35
+ num_kv_heads,
36
+ scale,
37
+ block_tables,
38
+ seq_lens,
39
+ block_size,
40
+ max_seq_len,
41
+ alibi_slopes,
42
+ kv_cache_dtype,
43
+ k_scale,
44
+ v_scale,
45
+ tp_rank,
46
+ blocksparse_local_blocks,
47
+ blocksparse_vert_stride,
48
+ blocksparse_block_size,
49
+ blocksparse_head_sliding_step,
50
+ )
51
+
52
+
53
+ def paged_attention_v2(
54
+ out: torch.Tensor,
55
+ exp_sum: torch.Tensor,
56
+ max_logits: torch.Tensor,
57
+ tmp_out: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key_cache: torch.Tensor,
60
+ value_cache: torch.Tensor,
61
+ num_kv_heads: int,
62
+ scale: float,
63
+ block_tables: torch.Tensor,
64
+ seq_lens: torch.Tensor,
65
+ block_size: int,
66
+ max_seq_len: int,
67
+ alibi_slopes: Optional[torch.Tensor],
68
+ kv_cache_dtype: str,
69
+ k_scale: float,
70
+ v_scale: float,
71
+ tp_rank: int = 0,
72
+ blocksparse_local_blocks: int = 0,
73
+ blocksparse_vert_stride: int = 0,
74
+ blocksparse_block_size: int = 64,
75
+ blocksparse_head_sliding_step: int = 0,
76
+ ) -> None:
77
+ ops.paged_attention_v2(
78
+ out,
79
+ exp_sum,
80
+ max_logits,
81
+ tmp_out,
82
+ query,
83
+ key_cache,
84
+ value_cache,
85
+ num_kv_heads,
86
+ scale,
87
+ block_tables,
88
+ seq_lens,
89
+ block_size,
90
+ max_seq_len,
91
+ alibi_slopes,
92
+ kv_cache_dtype,
93
+ k_scale,
94
+ v_scale,
95
+ tp_rank,
96
+ blocksparse_local_blocks,
97
+ blocksparse_vert_stride,
98
+ blocksparse_block_size,
99
+ blocksparse_head_sliding_step,
100
+ )
101
+
102
+
103
+ def reshape_and_cache(
104
+ key: torch.Tensor,
105
+ value: torch.Tensor,
106
+ key_cache: torch.Tensor,
107
+ value_cache: torch.Tensor,
108
+ slot_mapping: torch.Tensor,
109
+ kv_cache_dtype: str,
110
+ k_scale: float,
111
+ v_scale: float,
112
+ ) -> None:
113
+ ops.reshape_and_cache(
114
+ key,
115
+ value,
116
+ key_cache,
117
+ value_cache,
118
+ slot_mapping,
119
+ kv_cache_dtype,
120
+ k_scale,
121
+ v_scale,
122
+ )
123
+
124
+
125
+ def reshape_and_cache_flash(
126
+ key: torch.Tensor,
127
+ value: torch.Tensor,
128
+ key_cache: torch.Tensor,
129
+ value_cache: torch.Tensor,
130
+ slot_mapping: torch.Tensor,
131
+ kv_cache_dtype: str,
132
+ k_scale: torch.Tensor,
133
+ v_scale: torch.Tensor,
134
+ ) -> None:
135
+ ops.reshape_and_cache_flash(
136
+ key,
137
+ value,
138
+ key_cache,
139
+ value_cache,
140
+ slot_mapping,
141
+ kv_cache_dtype,
142
+ k_scale,
143
+ v_scale,
144
+ )
145
+
146
+
147
+ def copy_blocks(
148
+ key_caches: List[torch.Tensor],
149
+ value_caches: List[torch.Tensor],
150
+ block_mapping: torch.Tensor,
151
+ ) -> None:
152
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
153
+
154
+
155
+ def swap_blocks(
156
+ src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
157
+ ) -> None:
158
+ ops.swap_blocks(src, dst, block_mapping)
159
+
160
+
161
+ def convert_fp8(
162
+ output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
163
+ ) -> None:
164
+ ops.convert_fp8(output, input, scale, kv_dtype)
165
+
166
+
167
+ __all__ = [
168
+ "convert_fp8",
169
+ "paged_attention_v1",
170
+ "paged_attention_v2",
171
+ "reshape_and_cache",
172
+ "copy_blocks",
173
+ ]
build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _attention_vbhagz24hyij6
3
+ ops = torch.ops._attention_vbhagz24hyij6
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_attention_vbhagz24hyij6::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from functools import lru_cache, wraps
5
+ from typing import Callable, ParamSpec, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IS_ROCM = torch.version.hip is not None
11
+
12
+
13
+ class Platform(ABC):
14
+ @classmethod
15
+ def seed_everything(cls, seed: int) -> None:
16
+ """
17
+ Set the seed of each random module.
18
+ `torch.manual_seed` will set seed on all devices.
19
+
20
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
21
+ """
22
+ random.seed(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+
26
+ @abstractmethod
27
+ def get_device_name(self, device_id: int = 0) -> str: ...
28
+
29
+ @abstractmethod
30
+ def is_cuda(self) -> bool: ...
31
+
32
+ @abstractmethod
33
+ def is_rocm(self) -> bool: ...
34
+
35
+
36
+ class CudaPlatform(Platform):
37
+ @classmethod
38
+ @lru_cache(maxsize=8)
39
+ def get_device_name(cls, device_id: int = 0) -> str:
40
+ return torch.cuda.get_device_name(0)
41
+
42
+ def is_cuda(self) -> bool:
43
+ return True
44
+
45
+ def is_rocm(self) -> bool:
46
+ return False
47
+
48
+
49
+ class RocmPlatform(Platform):
50
+ @classmethod
51
+ @lru_cache(maxsize=8)
52
+ def get_device_name(cls, device_id: int = 0) -> str:
53
+ return torch.cuda.get_device_name(device_id)
54
+
55
+ def is_cuda(self) -> bool:
56
+ return False
57
+
58
+ def is_rocm(self) -> bool:
59
+ return True
60
+
61
+
62
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()