smpanaro commited on
Commit
f519906
1 Parent(s): e97e085

Faster argmax by skipping logit concat

Browse files
Llama-3.2-1B-Instruct_chunk6.mlmodelc/analytics/coremldata.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:23aa1a8a2b6bca88beeecf08d6281cdeb43aad33eb02cbe10ebff7eede7ed329
3
  size 243
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13c3ef8de95508f3e6076553fb737642a123494d4403aef5ebda7ce0c9acc236
3
  size 243
Llama-3.2-1B-Instruct_chunk6.mlmodelc/coremldata.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2740a88eec733758ae31768866e8550009ce36eb5177a09938d42cb18a095d05
3
- size 311
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0f38714827a836f9d17e85d57fd6906b69f0bab1eefba12c1ade1e3eaf9715b
3
+ size 501
Llama-3.2-1B-Instruct_chunk6.mlmodelc/metadata.json CHANGED
@@ -7,10 +7,80 @@
7
  "hasShapeFlexibility" : "0",
8
  "isOptional" : "0",
9
  "dataType" : "Float16",
10
- "formattedType" : "MultiArray (Float16 1 × 64 × 128256)",
11
  "shortDescription" : "",
12
- "shape" : "[1, 64, 128256]",
13
- "name" : "logits",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  "type" : "MultiArray"
15
  }
16
  ],
@@ -19,7 +89,7 @@
19
  ],
20
  "specificationVersion" : 7,
21
  "mlProgramOperationTypeHistogram" : {
22
- "Concat" : 2,
23
  "Ios16.mul" : 2,
24
  "Squeeze" : 1,
25
  "Transpose" : 1,
@@ -58,7 +128,7 @@
58
  "type" : "MultiArray"
59
  }
60
  ],
61
- "generatedClassName" : "Llama_3_2_1B_Instruct_2024_10_10_23_56_41_chunk6",
62
  "method" : "predict"
63
  }
64
  ]
 
7
  "hasShapeFlexibility" : "0",
8
  "isOptional" : "0",
9
  "dataType" : "Float16",
10
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
11
  "shortDescription" : "",
12
+ "shape" : "[1, 64, 16384]",
13
+ "name" : "logits_0",
14
+ "type" : "MultiArray"
15
+ },
16
+ {
17
+ "hasShapeFlexibility" : "0",
18
+ "isOptional" : "0",
19
+ "dataType" : "Float16",
20
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
21
+ "shortDescription" : "",
22
+ "shape" : "[1, 64, 16384]",
23
+ "name" : "logits_1",
24
+ "type" : "MultiArray"
25
+ },
26
+ {
27
+ "hasShapeFlexibility" : "0",
28
+ "isOptional" : "0",
29
+ "dataType" : "Float16",
30
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
31
+ "shortDescription" : "",
32
+ "shape" : "[1, 64, 16384]",
33
+ "name" : "logits_2",
34
+ "type" : "MultiArray"
35
+ },
36
+ {
37
+ "hasShapeFlexibility" : "0",
38
+ "isOptional" : "0",
39
+ "dataType" : "Float16",
40
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
41
+ "shortDescription" : "",
42
+ "shape" : "[1, 64, 16384]",
43
+ "name" : "logits_3",
44
+ "type" : "MultiArray"
45
+ },
46
+ {
47
+ "hasShapeFlexibility" : "0",
48
+ "isOptional" : "0",
49
+ "dataType" : "Float16",
50
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
51
+ "shortDescription" : "",
52
+ "shape" : "[1, 64, 16384]",
53
+ "name" : "logits_4",
54
+ "type" : "MultiArray"
55
+ },
56
+ {
57
+ "hasShapeFlexibility" : "0",
58
+ "isOptional" : "0",
59
+ "dataType" : "Float16",
60
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
61
+ "shortDescription" : "",
62
+ "shape" : "[1, 64, 16384]",
63
+ "name" : "logits_5",
64
+ "type" : "MultiArray"
65
+ },
66
+ {
67
+ "hasShapeFlexibility" : "0",
68
+ "isOptional" : "0",
69
+ "dataType" : "Float16",
70
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
71
+ "shortDescription" : "",
72
+ "shape" : "[1, 64, 16384]",
73
+ "name" : "logits_6",
74
+ "type" : "MultiArray"
75
+ },
76
+ {
77
+ "hasShapeFlexibility" : "0",
78
+ "isOptional" : "0",
79
+ "dataType" : "Float16",
80
+ "formattedType" : "MultiArray (Float16 1 × 64 × 13568)",
81
+ "shortDescription" : "",
82
+ "shape" : "[1, 64, 13568]",
83
+ "name" : "logits_7",
84
  "type" : "MultiArray"
85
  }
86
  ],
 
89
  ],
90
  "specificationVersion" : 7,
91
  "mlProgramOperationTypeHistogram" : {
92
+ "Concat" : 1,
93
  "Ios16.mul" : 2,
94
  "Squeeze" : 1,
95
  "Transpose" : 1,
 
128
  "type" : "MultiArray"
129
  }
130
  ],
131
+ "generatedClassName" : "Llama_3_2_1B_Instruct_2024_10_13_15_34_32_chunk3",
132
  "method" : "predict"
133
  }
134
  ]
Llama-3.2-1B-Instruct_chunk6.mlmodelc/model.mil CHANGED
@@ -27,51 +27,48 @@ program(1.0)
27
  tensor<fp16, [2048, 16384]> transpose_1_to_fp16 = const()[name = tensor<string, []>("transpose_1_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(4416)))];
28
  tensor<fp16, [64, 16384]> matmul_0_cast_fp16 = matmul(transpose_x = matmul_0_transpose_x_0, transpose_y = matmul_0_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_1_to_fp16)[name = tensor<string, []>("matmul_0_cast_fp16")];
29
  tensor<int32, [3]> concat_8 = const()[name = tensor<string, []>("concat_8"), val = tensor<int32, [3]>([1, 64, 16384])];
30
- tensor<fp16, [1, 64, 16384]> reshape_2_cast_fp16 = reshape(shape = concat_8, x = matmul_0_cast_fp16)[name = tensor<string, []>("reshape_2_cast_fp16")];
31
  tensor<bool, []> matmul_1_transpose_x_0 = const()[name = tensor<string, []>("matmul_1_transpose_x_0"), val = tensor<bool, []>(false)];
32
  tensor<bool, []> matmul_1_transpose_y_0 = const()[name = tensor<string, []>("matmul_1_transpose_y_0"), val = tensor<bool, []>(false)];
33
  tensor<fp16, [2048, 16384]> transpose_3_to_fp16 = const()[name = tensor<string, []>("transpose_3_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(67113344)))];
34
  tensor<fp16, [64, 16384]> matmul_1_cast_fp16 = matmul(transpose_x = matmul_1_transpose_x_0, transpose_y = matmul_1_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_3_to_fp16)[name = tensor<string, []>("matmul_1_cast_fp16")];
35
  tensor<int32, [3]> concat_16 = const()[name = tensor<string, []>("concat_16"), val = tensor<int32, [3]>([1, 64, 16384])];
36
- tensor<fp16, [1, 64, 16384]> reshape_5_cast_fp16 = reshape(shape = concat_16, x = matmul_1_cast_fp16)[name = tensor<string, []>("reshape_5_cast_fp16")];
37
  tensor<bool, []> matmul_2_transpose_x_0 = const()[name = tensor<string, []>("matmul_2_transpose_x_0"), val = tensor<bool, []>(false)];
38
  tensor<bool, []> matmul_2_transpose_y_0 = const()[name = tensor<string, []>("matmul_2_transpose_y_0"), val = tensor<bool, []>(false)];
39
  tensor<fp16, [2048, 16384]> transpose_5_to_fp16 = const()[name = tensor<string, []>("transpose_5_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(134222272)))];
40
  tensor<fp16, [64, 16384]> matmul_2_cast_fp16 = matmul(transpose_x = matmul_2_transpose_x_0, transpose_y = matmul_2_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_5_to_fp16)[name = tensor<string, []>("matmul_2_cast_fp16")];
41
  tensor<int32, [3]> concat_24 = const()[name = tensor<string, []>("concat_24"), val = tensor<int32, [3]>([1, 64, 16384])];
42
- tensor<fp16, [1, 64, 16384]> reshape_8_cast_fp16 = reshape(shape = concat_24, x = matmul_2_cast_fp16)[name = tensor<string, []>("reshape_8_cast_fp16")];
43
  tensor<bool, []> matmul_3_transpose_x_0 = const()[name = tensor<string, []>("matmul_3_transpose_x_0"), val = tensor<bool, []>(false)];
44
  tensor<bool, []> matmul_3_transpose_y_0 = const()[name = tensor<string, []>("matmul_3_transpose_y_0"), val = tensor<bool, []>(false)];
45
  tensor<fp16, [2048, 16384]> transpose_7_to_fp16 = const()[name = tensor<string, []>("transpose_7_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(201331200)))];
46
  tensor<fp16, [64, 16384]> matmul_3_cast_fp16 = matmul(transpose_x = matmul_3_transpose_x_0, transpose_y = matmul_3_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_7_to_fp16)[name = tensor<string, []>("matmul_3_cast_fp16")];
47
  tensor<int32, [3]> concat_32 = const()[name = tensor<string, []>("concat_32"), val = tensor<int32, [3]>([1, 64, 16384])];
48
- tensor<fp16, [1, 64, 16384]> reshape_11_cast_fp16 = reshape(shape = concat_32, x = matmul_3_cast_fp16)[name = tensor<string, []>("reshape_11_cast_fp16")];
49
  tensor<bool, []> matmul_4_transpose_x_0 = const()[name = tensor<string, []>("matmul_4_transpose_x_0"), val = tensor<bool, []>(false)];
50
  tensor<bool, []> matmul_4_transpose_y_0 = const()[name = tensor<string, []>("matmul_4_transpose_y_0"), val = tensor<bool, []>(false)];
51
  tensor<fp16, [2048, 16384]> transpose_9_to_fp16 = const()[name = tensor<string, []>("transpose_9_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(268440128)))];
52
  tensor<fp16, [64, 16384]> matmul_4_cast_fp16 = matmul(transpose_x = matmul_4_transpose_x_0, transpose_y = matmul_4_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_9_to_fp16)[name = tensor<string, []>("matmul_4_cast_fp16")];
53
  tensor<int32, [3]> concat_40 = const()[name = tensor<string, []>("concat_40"), val = tensor<int32, [3]>([1, 64, 16384])];
54
- tensor<fp16, [1, 64, 16384]> reshape_14_cast_fp16 = reshape(shape = concat_40, x = matmul_4_cast_fp16)[name = tensor<string, []>("reshape_14_cast_fp16")];
55
  tensor<bool, []> matmul_5_transpose_x_0 = const()[name = tensor<string, []>("matmul_5_transpose_x_0"), val = tensor<bool, []>(false)];
56
  tensor<bool, []> matmul_5_transpose_y_0 = const()[name = tensor<string, []>("matmul_5_transpose_y_0"), val = tensor<bool, []>(false)];
57
  tensor<fp16, [2048, 16384]> transpose_11_to_fp16 = const()[name = tensor<string, []>("transpose_11_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(335549056)))];
58
  tensor<fp16, [64, 16384]> matmul_5_cast_fp16 = matmul(transpose_x = matmul_5_transpose_x_0, transpose_y = matmul_5_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_11_to_fp16)[name = tensor<string, []>("matmul_5_cast_fp16")];
59
  tensor<int32, [3]> concat_48 = const()[name = tensor<string, []>("concat_48"), val = tensor<int32, [3]>([1, 64, 16384])];
60
- tensor<fp16, [1, 64, 16384]> reshape_17_cast_fp16 = reshape(shape = concat_48, x = matmul_5_cast_fp16)[name = tensor<string, []>("reshape_17_cast_fp16")];
61
  tensor<bool, []> matmul_6_transpose_x_0 = const()[name = tensor<string, []>("matmul_6_transpose_x_0"), val = tensor<bool, []>(false)];
62
  tensor<bool, []> matmul_6_transpose_y_0 = const()[name = tensor<string, []>("matmul_6_transpose_y_0"), val = tensor<bool, []>(false)];
63
  tensor<fp16, [2048, 16384]> transpose_13_to_fp16 = const()[name = tensor<string, []>("transpose_13_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(402657984)))];
64
  tensor<fp16, [64, 16384]> matmul_6_cast_fp16 = matmul(transpose_x = matmul_6_transpose_x_0, transpose_y = matmul_6_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_13_to_fp16)[name = tensor<string, []>("matmul_6_cast_fp16")];
65
  tensor<int32, [3]> concat_56 = const()[name = tensor<string, []>("concat_56"), val = tensor<int32, [3]>([1, 64, 16384])];
66
- tensor<fp16, [1, 64, 16384]> reshape_20_cast_fp16 = reshape(shape = concat_56, x = matmul_6_cast_fp16)[name = tensor<string, []>("reshape_20_cast_fp16")];
67
  tensor<bool, []> matmul_7_transpose_x_0 = const()[name = tensor<string, []>("matmul_7_transpose_x_0"), val = tensor<bool, []>(false)];
68
  tensor<bool, []> matmul_7_transpose_y_0 = const()[name = tensor<string, []>("matmul_7_transpose_y_0"), val = tensor<bool, []>(false)];
69
  tensor<fp16, [2048, 13568]> transpose_15_to_fp16 = const()[name = tensor<string, []>("transpose_15_to_fp16"), val = tensor<fp16, [2048, 13568]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(469766912)))];
70
  tensor<fp16, [64, 13568]> matmul_7_cast_fp16 = matmul(transpose_x = matmul_7_transpose_x_0, transpose_y = matmul_7_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_15_to_fp16)[name = tensor<string, []>("matmul_7_cast_fp16")];
71
  tensor<int32, [3]> concat_64 = const()[name = tensor<string, []>("concat_64"), val = tensor<int32, [3]>([1, 64, 13568])];
72
- tensor<fp16, [1, 64, 13568]> reshape_23_cast_fp16 = reshape(shape = concat_64, x = matmul_7_cast_fp16)[name = tensor<string, []>("reshape_23_cast_fp16")];
73
- tensor<int32, []> var_99 = const()[name = tensor<string, []>("op_99"), val = tensor<int32, []>(-1)];
74
- tensor<bool, []> var_100_interleave_0 = const()[name = tensor<string, []>("op_100_interleave_0"), val = tensor<bool, []>(false)];
75
- tensor<fp16, [1, 64, 128256]> logits = concat(axis = var_99, interleave = var_100_interleave_0, values = (reshape_2_cast_fp16, reshape_5_cast_fp16, reshape_8_cast_fp16, reshape_11_cast_fp16, reshape_14_cast_fp16, reshape_17_cast_fp16, reshape_20_cast_fp16, reshape_23_cast_fp16))[name = tensor<string, []>("op_100_cast_fp16")];
76
- } -> (logits);
77
  }
 
27
  tensor<fp16, [2048, 16384]> transpose_1_to_fp16 = const()[name = tensor<string, []>("transpose_1_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(4416)))];
28
  tensor<fp16, [64, 16384]> matmul_0_cast_fp16 = matmul(transpose_x = matmul_0_transpose_x_0, transpose_y = matmul_0_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_1_to_fp16)[name = tensor<string, []>("matmul_0_cast_fp16")];
29
  tensor<int32, [3]> concat_8 = const()[name = tensor<string, []>("concat_8"), val = tensor<int32, [3]>([1, 64, 16384])];
30
+ tensor<fp16, [1, 64, 16384]> logits_0 = reshape(shape = concat_8, x = matmul_0_cast_fp16)[name = tensor<string, []>("reshape_2_cast_fp16")];
31
  tensor<bool, []> matmul_1_transpose_x_0 = const()[name = tensor<string, []>("matmul_1_transpose_x_0"), val = tensor<bool, []>(false)];
32
  tensor<bool, []> matmul_1_transpose_y_0 = const()[name = tensor<string, []>("matmul_1_transpose_y_0"), val = tensor<bool, []>(false)];
33
  tensor<fp16, [2048, 16384]> transpose_3_to_fp16 = const()[name = tensor<string, []>("transpose_3_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(67113344)))];
34
  tensor<fp16, [64, 16384]> matmul_1_cast_fp16 = matmul(transpose_x = matmul_1_transpose_x_0, transpose_y = matmul_1_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_3_to_fp16)[name = tensor<string, []>("matmul_1_cast_fp16")];
35
  tensor<int32, [3]> concat_16 = const()[name = tensor<string, []>("concat_16"), val = tensor<int32, [3]>([1, 64, 16384])];
36
+ tensor<fp16, [1, 64, 16384]> logits_1 = reshape(shape = concat_16, x = matmul_1_cast_fp16)[name = tensor<string, []>("reshape_5_cast_fp16")];
37
  tensor<bool, []> matmul_2_transpose_x_0 = const()[name = tensor<string, []>("matmul_2_transpose_x_0"), val = tensor<bool, []>(false)];
38
  tensor<bool, []> matmul_2_transpose_y_0 = const()[name = tensor<string, []>("matmul_2_transpose_y_0"), val = tensor<bool, []>(false)];
39
  tensor<fp16, [2048, 16384]> transpose_5_to_fp16 = const()[name = tensor<string, []>("transpose_5_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(134222272)))];
40
  tensor<fp16, [64, 16384]> matmul_2_cast_fp16 = matmul(transpose_x = matmul_2_transpose_x_0, transpose_y = matmul_2_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_5_to_fp16)[name = tensor<string, []>("matmul_2_cast_fp16")];
41
  tensor<int32, [3]> concat_24 = const()[name = tensor<string, []>("concat_24"), val = tensor<int32, [3]>([1, 64, 16384])];
42
+ tensor<fp16, [1, 64, 16384]> logits_2 = reshape(shape = concat_24, x = matmul_2_cast_fp16)[name = tensor<string, []>("reshape_8_cast_fp16")];
43
  tensor<bool, []> matmul_3_transpose_x_0 = const()[name = tensor<string, []>("matmul_3_transpose_x_0"), val = tensor<bool, []>(false)];
44
  tensor<bool, []> matmul_3_transpose_y_0 = const()[name = tensor<string, []>("matmul_3_transpose_y_0"), val = tensor<bool, []>(false)];
45
  tensor<fp16, [2048, 16384]> transpose_7_to_fp16 = const()[name = tensor<string, []>("transpose_7_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(201331200)))];
46
  tensor<fp16, [64, 16384]> matmul_3_cast_fp16 = matmul(transpose_x = matmul_3_transpose_x_0, transpose_y = matmul_3_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_7_to_fp16)[name = tensor<string, []>("matmul_3_cast_fp16")];
47
  tensor<int32, [3]> concat_32 = const()[name = tensor<string, []>("concat_32"), val = tensor<int32, [3]>([1, 64, 16384])];
48
+ tensor<fp16, [1, 64, 16384]> logits_3 = reshape(shape = concat_32, x = matmul_3_cast_fp16)[name = tensor<string, []>("reshape_11_cast_fp16")];
49
  tensor<bool, []> matmul_4_transpose_x_0 = const()[name = tensor<string, []>("matmul_4_transpose_x_0"), val = tensor<bool, []>(false)];
50
  tensor<bool, []> matmul_4_transpose_y_0 = const()[name = tensor<string, []>("matmul_4_transpose_y_0"), val = tensor<bool, []>(false)];
51
  tensor<fp16, [2048, 16384]> transpose_9_to_fp16 = const()[name = tensor<string, []>("transpose_9_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(268440128)))];
52
  tensor<fp16, [64, 16384]> matmul_4_cast_fp16 = matmul(transpose_x = matmul_4_transpose_x_0, transpose_y = matmul_4_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_9_to_fp16)[name = tensor<string, []>("matmul_4_cast_fp16")];
53
  tensor<int32, [3]> concat_40 = const()[name = tensor<string, []>("concat_40"), val = tensor<int32, [3]>([1, 64, 16384])];
54
+ tensor<fp16, [1, 64, 16384]> logits_4 = reshape(shape = concat_40, x = matmul_4_cast_fp16)[name = tensor<string, []>("reshape_14_cast_fp16")];
55
  tensor<bool, []> matmul_5_transpose_x_0 = const()[name = tensor<string, []>("matmul_5_transpose_x_0"), val = tensor<bool, []>(false)];
56
  tensor<bool, []> matmul_5_transpose_y_0 = const()[name = tensor<string, []>("matmul_5_transpose_y_0"), val = tensor<bool, []>(false)];
57
  tensor<fp16, [2048, 16384]> transpose_11_to_fp16 = const()[name = tensor<string, []>("transpose_11_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(335549056)))];
58
  tensor<fp16, [64, 16384]> matmul_5_cast_fp16 = matmul(transpose_x = matmul_5_transpose_x_0, transpose_y = matmul_5_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_11_to_fp16)[name = tensor<string, []>("matmul_5_cast_fp16")];
59
  tensor<int32, [3]> concat_48 = const()[name = tensor<string, []>("concat_48"), val = tensor<int32, [3]>([1, 64, 16384])];
60
+ tensor<fp16, [1, 64, 16384]> logits_5 = reshape(shape = concat_48, x = matmul_5_cast_fp16)[name = tensor<string, []>("reshape_17_cast_fp16")];
61
  tensor<bool, []> matmul_6_transpose_x_0 = const()[name = tensor<string, []>("matmul_6_transpose_x_0"), val = tensor<bool, []>(false)];
62
  tensor<bool, []> matmul_6_transpose_y_0 = const()[name = tensor<string, []>("matmul_6_transpose_y_0"), val = tensor<bool, []>(false)];
63
  tensor<fp16, [2048, 16384]> transpose_13_to_fp16 = const()[name = tensor<string, []>("transpose_13_to_fp16"), val = tensor<fp16, [2048, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(402657984)))];
64
  tensor<fp16, [64, 16384]> matmul_6_cast_fp16 = matmul(transpose_x = matmul_6_transpose_x_0, transpose_y = matmul_6_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_13_to_fp16)[name = tensor<string, []>("matmul_6_cast_fp16")];
65
  tensor<int32, [3]> concat_56 = const()[name = tensor<string, []>("concat_56"), val = tensor<int32, [3]>([1, 64, 16384])];
66
+ tensor<fp16, [1, 64, 16384]> logits_6 = reshape(shape = concat_56, x = matmul_6_cast_fp16)[name = tensor<string, []>("reshape_20_cast_fp16")];
67
  tensor<bool, []> matmul_7_transpose_x_0 = const()[name = tensor<string, []>("matmul_7_transpose_x_0"), val = tensor<bool, []>(false)];
68
  tensor<bool, []> matmul_7_transpose_y_0 = const()[name = tensor<string, []>("matmul_7_transpose_y_0"), val = tensor<bool, []>(false)];
69
  tensor<fp16, [2048, 13568]> transpose_15_to_fp16 = const()[name = tensor<string, []>("transpose_15_to_fp16"), val = tensor<fp16, [2048, 13568]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(469766912)))];
70
  tensor<fp16, [64, 13568]> matmul_7_cast_fp16 = matmul(transpose_x = matmul_7_transpose_x_0, transpose_y = matmul_7_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_15_to_fp16)[name = tensor<string, []>("matmul_7_cast_fp16")];
71
  tensor<int32, [3]> concat_64 = const()[name = tensor<string, []>("concat_64"), val = tensor<int32, [3]>([1, 64, 13568])];
72
+ tensor<fp16, [1, 64, 13568]> logits_7 = reshape(shape = concat_64, x = matmul_7_cast_fp16)[name = tensor<string, []>("reshape_23_cast_fp16")];
73
+ } -> (logits_0, logits_1, logits_2, logits_3, logits_4, logits_5, logits_6, logits_7);
 
 
 
74
  }
logit-processor.mlmodelc/analytics/coremldata.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0ad03dc247f59282bf008d857db8620b0ad600eb939bfa2a4e8a78438e1c2573
3
  size 243
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cea8f79e82c95d93f772797047802fb88c7fc82dfcef790e69a2f274a104623
3
  size 243
logit-processor.mlmodelc/coremldata.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ccca55190c5da56bfc175471f3239eeeb7bffece8d38d565de9443edef9c9148
3
- size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d70f289c1552a24ba6ca721405ea0653ac7125a91297ea3d32b30363c2afd3c
3
+ size 503
logit-processor.mlmodelc/metadata.json CHANGED
@@ -6,9 +6,9 @@
6
  "hasShapeFlexibility" : "0",
7
  "isOptional" : "0",
8
  "dataType" : "Int32",
9
- "formattedType" : "MultiArray (Int32)",
10
  "shortDescription" : "",
11
- "shape" : "[]",
12
  "name" : "argmax",
13
  "type" : "MultiArray"
14
  }
@@ -18,7 +18,11 @@
18
  ],
19
  "specificationVersion" : 7,
20
  "mlProgramOperationTypeHistogram" : {
21
- "Ios16.reduceArgmax" : 1
 
 
 
 
22
  },
23
  "computePrecision" : "Mixed (Float16, Int32)",
24
  "isUpdatable" : "0",
@@ -40,19 +44,87 @@
40
  },
41
  "inputSchema" : [
42
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  "shortDescription" : "",
 
 
 
 
 
 
 
44
  "dataType" : "Float16",
45
- "hasShapeFlexibility" : "1",
 
 
 
 
 
 
 
46
  "isOptional" : "0",
47
- "shapeFlexibility" : "1 × 511 × 32000 | 1 × 1 × 32000 | 1 × 2 × 32000 | 1 × 4 × 32000 | 1 × 64 × 32000 | 1 × 64 × 128256 | 1 × 512 × 32000",
48
- "formattedType" : "MultiArray (Float16 1 × 511 × 32000)",
49
- "type" : "MultiArray",
50
- "shape" : "[1, 511, 32000]",
51
- "name" : "logits",
52
- "enumeratedShapes" : "[[1, 511, 32000], [1, 1, 32000], [1, 2, 32000], [1, 4, 32000], [1, 64, 32000], [1, 64, 128256], [1, 512, 32000]]"
53
  }
54
  ],
55
- "generatedClassName" : "logit_processor",
56
  "method" : "predict"
57
  }
58
  ]
 
6
  "hasShapeFlexibility" : "0",
7
  "isOptional" : "0",
8
  "dataType" : "Int32",
9
+ "formattedType" : "MultiArray (Int32 1 × 64)",
10
  "shortDescription" : "",
11
+ "shape" : "[1, 64]",
12
  "name" : "argmax",
13
  "type" : "MultiArray"
14
  }
 
18
  ],
19
  "specificationVersion" : 7,
20
  "mlProgramOperationTypeHistogram" : {
21
+ "Ios16.add" : 7,
22
+ "Ios16.topk" : 9,
23
+ "Ios16.gatherAlongAxis" : 1,
24
+ "Concat" : 2,
25
+ "Squeeze" : 1
26
  },
27
  "computePrecision" : "Mixed (Float16, Int32)",
28
  "isUpdatable" : "0",
 
44
  },
45
  "inputSchema" : [
46
  {
47
+ "hasShapeFlexibility" : "0",
48
+ "isOptional" : "0",
49
+ "dataType" : "Float16",
50
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
51
+ "shortDescription" : "",
52
+ "shape" : "[1, 64, 16384]",
53
+ "name" : "logits_0",
54
+ "type" : "MultiArray"
55
+ },
56
+ {
57
+ "hasShapeFlexibility" : "0",
58
+ "isOptional" : "0",
59
+ "dataType" : "Float16",
60
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
61
+ "shortDescription" : "",
62
+ "shape" : "[1, 64, 16384]",
63
+ "name" : "logits_1",
64
+ "type" : "MultiArray"
65
+ },
66
+ {
67
+ "hasShapeFlexibility" : "0",
68
+ "isOptional" : "0",
69
+ "dataType" : "Float16",
70
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
71
+ "shortDescription" : "",
72
+ "shape" : "[1, 64, 16384]",
73
+ "name" : "logits_2",
74
+ "type" : "MultiArray"
75
+ },
76
+ {
77
+ "hasShapeFlexibility" : "0",
78
+ "isOptional" : "0",
79
+ "dataType" : "Float16",
80
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
81
+ "shortDescription" : "",
82
+ "shape" : "[1, 64, 16384]",
83
+ "name" : "logits_3",
84
+ "type" : "MultiArray"
85
+ },
86
+ {
87
+ "hasShapeFlexibility" : "0",
88
+ "isOptional" : "0",
89
+ "dataType" : "Float16",
90
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
91
+ "shortDescription" : "",
92
+ "shape" : "[1, 64, 16384]",
93
+ "name" : "logits_4",
94
+ "type" : "MultiArray"
95
+ },
96
+ {
97
+ "hasShapeFlexibility" : "0",
98
+ "isOptional" : "0",
99
+ "dataType" : "Float16",
100
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
101
  "shortDescription" : "",
102
+ "shape" : "[1, 64, 16384]",
103
+ "name" : "logits_5",
104
+ "type" : "MultiArray"
105
+ },
106
+ {
107
+ "hasShapeFlexibility" : "0",
108
+ "isOptional" : "0",
109
  "dataType" : "Float16",
110
+ "formattedType" : "MultiArray (Float16 1 × 64 × 16384)",
111
+ "shortDescription" : "",
112
+ "shape" : "[1, 64, 16384]",
113
+ "name" : "logits_6",
114
+ "type" : "MultiArray"
115
+ },
116
+ {
117
+ "hasShapeFlexibility" : "0",
118
  "isOptional" : "0",
119
+ "dataType" : "Float16",
120
+ "formattedType" : "MultiArray (Float16 1 × 64 × 13568)",
121
+ "shortDescription" : "",
122
+ "shape" : "[1, 64, 13568]",
123
+ "name" : "logits_7",
124
+ "type" : "MultiArray"
125
  }
126
  ],
127
+ "generatedClassName" : "split_logit_processor",
128
  "method" : "predict"
129
  }
130
  ]
logit-processor.mlmodelc/model.mil CHANGED
@@ -1,9 +1,84 @@
1
  program(1.0)
2
  [buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
3
  {
4
- func main<ios16>(tensor<fp16, [1, ?, ?]> logits) [FlexibleShapeInformation = tuple<tuple<tensor<string, []>, dict<tensor<string, []>, tensor<int32, [?]>>>, tuple<tensor<string, []>, dict<tensor<string, []>, dict<tensor<string, []>, tensor<int32, [?]>>>>>((("DefaultShapes", {{"logits", [1, 511, 32000]}}), ("EnumeratedShapes", {{"logits_1_1_1_1_32000_", {{"logits", [1, 1, 32000]}}}, {"logits_1_1_1_2_32000_", {{"logits", [1, 2, 32000]}}}, {"logits_1_1_1_4_32000_", {{"logits", [1, 4, 32000]}}}, {"logits_1_1_1_511_32000_", {{"logits", [1, 511, 32000]}}}, {"logits_1_1_1_512_32000_", {{"logits", [1, 512, 32000]}}}, {"logits_1_1_1_64_128256_", {{"logits", [1, 64, 128256]}}}, {"logits_1_1_1_64_32000_", {{"logits", [1, 64, 32000]}}}})))] {
5
- tensor<int32, []> var_2 = const()[name = tensor<string, []>("op_2"), val = tensor<int32, []>(-1)];
6
- tensor<bool, []> var_3 = const()[name = tensor<string, []>("op_3"), val = tensor<bool, []>(false)];
7
- tensor<int32, [1, ?]> argmax = reduce_argmax(axis = var_2, keep_dims = var_3, x = logits)[name = tensor<string, []>("op_4_cast_fp16")];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  } -> (argmax);
9
  }
 
1
  program(1.0)
2
  [buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
3
  {
4
+ func main<ios16>(tensor<fp16, [1, 64, 16384]> logits_0, tensor<fp16, [1, 64, 16384]> logits_1, tensor<fp16, [1, 64, 16384]> logits_2, tensor<fp16, [1, 64, 16384]> logits_3, tensor<fp16, [1, 64, 16384]> logits_4, tensor<fp16, [1, 64, 16384]> logits_5, tensor<fp16, [1, 64, 16384]> logits_6, tensor<fp16, [1, 64, 13568]> logits_7) {
5
+ tensor<int32, [1]> chunk_size = const()[name = tensor<string, []>("chunk_size"), val = tensor<int32, [1]>([16384])];
6
+ tensor<int32, []> var_12 = const()[name = tensor<string, []>("op_12"), val = tensor<int32, []>(1)];
7
+ tensor<int32, []> var_16_axis_0 = const()[name = tensor<string, []>("op_16_axis_0"), val = tensor<int32, []>(-1)];
8
+ tensor<bool, []> var_16_ascending_0 = const()[name = tensor<string, []>("op_16_ascending_0"), val = tensor<bool, []>(false)];
9
+ tensor<bool, []> var_16_sort_0 = const()[name = tensor<string, []>("op_16_sort_0"), val = tensor<bool, []>(false)];
10
+ tensor<bool, []> var_16_return_indices_0 = const()[name = tensor<string, []>("op_16_return_indices_0"), val = tensor<bool, []>(true)];
11
+ tensor<fp16, [1, 64, 1]> var_16_cast_fp16_0, tensor<int32, [1, 64, 1]> var_16_cast_fp16_1 = topk(ascending = var_16_ascending_0, axis = var_16_axis_0, k = var_12, return_indices = var_16_return_indices_0, sort = var_16_sort_0, x = logits_0)[name = tensor<string, []>("op_16_cast_fp16")];
12
+ tensor<int32, []> var_22 = const()[name = tensor<string, []>("op_22"), val = tensor<int32, []>(1)];
13
+ tensor<int32, []> var_26_axis_0 = const()[name = tensor<string, []>("op_26_axis_0"), val = tensor<int32, []>(-1)];
14
+ tensor<bool, []> var_26_ascending_0 = const()[name = tensor<string, []>("op_26_ascending_0"), val = tensor<bool, []>(false)];
15
+ tensor<bool, []> var_26_sort_0 = const()[name = tensor<string, []>("op_26_sort_0"), val = tensor<bool, []>(false)];
16
+ tensor<bool, []> var_26_return_indices_0 = const()[name = tensor<string, []>("op_26_return_indices_0"), val = tensor<bool, []>(true)];
17
+ tensor<fp16, [1, 64, 1]> var_26_cast_fp16_0, tensor<int32, [1, 64, 1]> var_26_cast_fp16_1 = topk(ascending = var_26_ascending_0, axis = var_26_axis_0, k = var_22, return_indices = var_26_return_indices_0, sort = var_26_sort_0, x = logits_1)[name = tensor<string, []>("op_26_cast_fp16")];
18
+ tensor<int32, [1, 64, 1]> var_31 = add(x = var_26_cast_fp16_1, y = chunk_size)[name = tensor<string, []>("op_31")];
19
+ tensor<int32, []> var_32 = const()[name = tensor<string, []>("op_32"), val = tensor<int32, []>(1)];
20
+ tensor<int32, []> var_36_axis_0 = const()[name = tensor<string, []>("op_36_axis_0"), val = tensor<int32, []>(-1)];
21
+ tensor<bool, []> var_36_ascending_0 = const()[name = tensor<string, []>("op_36_ascending_0"), val = tensor<bool, []>(false)];
22
+ tensor<bool, []> var_36_sort_0 = const()[name = tensor<string, []>("op_36_sort_0"), val = tensor<bool, []>(false)];
23
+ tensor<bool, []> var_36_return_indices_0 = const()[name = tensor<string, []>("op_36_return_indices_0"), val = tensor<bool, []>(true)];
24
+ tensor<fp16, [1, 64, 1]> var_36_cast_fp16_0, tensor<int32, [1, 64, 1]> var_36_cast_fp16_1 = topk(ascending = var_36_ascending_0, axis = var_36_axis_0, k = var_32, return_indices = var_36_return_indices_0, sort = var_36_sort_0, x = logits_2)[name = tensor<string, []>("op_36_cast_fp16")];
25
+ tensor<int32, [1]> var_39 = const()[name = tensor<string, []>("op_39"), val = tensor<int32, [1]>([32768])];
26
+ tensor<int32, [1, 64, 1]> var_41 = add(x = var_36_cast_fp16_1, y = var_39)[name = tensor<string, []>("op_41")];
27
+ tensor<int32, []> var_42 = const()[name = tensor<string, []>("op_42"), val = tensor<int32, []>(1)];
28
+ tensor<int32, []> var_46_axis_0 = const()[name = tensor<string, []>("op_46_axis_0"), val = tensor<int32, []>(-1)];
29
+ tensor<bool, []> var_46_ascending_0 = const()[name = tensor<string, []>("op_46_ascending_0"), val = tensor<bool, []>(false)];
30
+ tensor<bool, []> var_46_sort_0 = const()[name = tensor<string, []>("op_46_sort_0"), val = tensor<bool, []>(false)];
31
+ tensor<bool, []> var_46_return_indices_0 = const()[name = tensor<string, []>("op_46_return_indices_0"), val = tensor<bool, []>(true)];
32
+ tensor<fp16, [1, 64, 1]> var_46_cast_fp16_0, tensor<int32, [1, 64, 1]> var_46_cast_fp16_1 = topk(ascending = var_46_ascending_0, axis = var_46_axis_0, k = var_42, return_indices = var_46_return_indices_0, sort = var_46_sort_0, x = logits_3)[name = tensor<string, []>("op_46_cast_fp16")];
33
+ tensor<int32, [1]> var_49 = const()[name = tensor<string, []>("op_49"), val = tensor<int32, [1]>([49152])];
34
+ tensor<int32, [1, 64, 1]> var_51 = add(x = var_46_cast_fp16_1, y = var_49)[name = tensor<string, []>("op_51")];
35
+ tensor<int32, []> var_52 = const()[name = tensor<string, []>("op_52"), val = tensor<int32, []>(1)];
36
+ tensor<int32, []> var_56_axis_0 = const()[name = tensor<string, []>("op_56_axis_0"), val = tensor<int32, []>(-1)];
37
+ tensor<bool, []> var_56_ascending_0 = const()[name = tensor<string, []>("op_56_ascending_0"), val = tensor<bool, []>(false)];
38
+ tensor<bool, []> var_56_sort_0 = const()[name = tensor<string, []>("op_56_sort_0"), val = tensor<bool, []>(false)];
39
+ tensor<bool, []> var_56_return_indices_0 = const()[name = tensor<string, []>("op_56_return_indices_0"), val = tensor<bool, []>(true)];
40
+ tensor<fp16, [1, 64, 1]> var_56_cast_fp16_0, tensor<int32, [1, 64, 1]> var_56_cast_fp16_1 = topk(ascending = var_56_ascending_0, axis = var_56_axis_0, k = var_52, return_indices = var_56_return_indices_0, sort = var_56_sort_0, x = logits_4)[name = tensor<string, []>("op_56_cast_fp16")];
41
+ tensor<int32, [1]> var_59 = const()[name = tensor<string, []>("op_59"), val = tensor<int32, [1]>([65536])];
42
+ tensor<int32, [1, 64, 1]> var_61 = add(x = var_56_cast_fp16_1, y = var_59)[name = tensor<string, []>("op_61")];
43
+ tensor<int32, []> var_62 = const()[name = tensor<string, []>("op_62"), val = tensor<int32, []>(1)];
44
+ tensor<int32, []> var_66_axis_0 = const()[name = tensor<string, []>("op_66_axis_0"), val = tensor<int32, []>(-1)];
45
+ tensor<bool, []> var_66_ascending_0 = const()[name = tensor<string, []>("op_66_ascending_0"), val = tensor<bool, []>(false)];
46
+ tensor<bool, []> var_66_sort_0 = const()[name = tensor<string, []>("op_66_sort_0"), val = tensor<bool, []>(false)];
47
+ tensor<bool, []> var_66_return_indices_0 = const()[name = tensor<string, []>("op_66_return_indices_0"), val = tensor<bool, []>(true)];
48
+ tensor<fp16, [1, 64, 1]> var_66_cast_fp16_0, tensor<int32, [1, 64, 1]> var_66_cast_fp16_1 = topk(ascending = var_66_ascending_0, axis = var_66_axis_0, k = var_62, return_indices = var_66_return_indices_0, sort = var_66_sort_0, x = logits_5)[name = tensor<string, []>("op_66_cast_fp16")];
49
+ tensor<int32, [1]> var_69 = const()[name = tensor<string, []>("op_69"), val = tensor<int32, [1]>([81920])];
50
+ tensor<int32, [1, 64, 1]> var_71 = add(x = var_66_cast_fp16_1, y = var_69)[name = tensor<string, []>("op_71")];
51
+ tensor<int32, []> var_72 = const()[name = tensor<string, []>("op_72"), val = tensor<int32, []>(1)];
52
+ tensor<int32, []> var_76_axis_0 = const()[name = tensor<string, []>("op_76_axis_0"), val = tensor<int32, []>(-1)];
53
+ tensor<bool, []> var_76_ascending_0 = const()[name = tensor<string, []>("op_76_ascending_0"), val = tensor<bool, []>(false)];
54
+ tensor<bool, []> var_76_sort_0 = const()[name = tensor<string, []>("op_76_sort_0"), val = tensor<bool, []>(false)];
55
+ tensor<bool, []> var_76_return_indices_0 = const()[name = tensor<string, []>("op_76_return_indices_0"), val = tensor<bool, []>(true)];
56
+ tensor<fp16, [1, 64, 1]> var_76_cast_fp16_0, tensor<int32, [1, 64, 1]> var_76_cast_fp16_1 = topk(ascending = var_76_ascending_0, axis = var_76_axis_0, k = var_72, return_indices = var_76_return_indices_0, sort = var_76_sort_0, x = logits_6)[name = tensor<string, []>("op_76_cast_fp16")];
57
+ tensor<int32, [1]> var_79 = const()[name = tensor<string, []>("op_79"), val = tensor<int32, [1]>([98304])];
58
+ tensor<int32, [1, 64, 1]> var_81 = add(x = var_76_cast_fp16_1, y = var_79)[name = tensor<string, []>("op_81")];
59
+ tensor<int32, []> var_82 = const()[name = tensor<string, []>("op_82"), val = tensor<int32, []>(1)];
60
+ tensor<int32, []> cv_axis_0 = const()[name = tensor<string, []>("cv_axis_0"), val = tensor<int32, []>(-1)];
61
+ tensor<bool, []> cv_ascending_0 = const()[name = tensor<string, []>("cv_ascending_0"), val = tensor<bool, []>(false)];
62
+ tensor<bool, []> cv_sort_0 = const()[name = tensor<string, []>("cv_sort_0"), val = tensor<bool, []>(false)];
63
+ tensor<bool, []> cv_return_indices_0 = const()[name = tensor<string, []>("cv_return_indices_0"), val = tensor<bool, []>(true)];
64
+ tensor<fp16, [1, 64, 1]> cv_cast_fp16_0, tensor<int32, [1, 64, 1]> cv_cast_fp16_1 = topk(ascending = cv_ascending_0, axis = cv_axis_0, k = var_82, return_indices = cv_return_indices_0, sort = cv_sort_0, x = logits_7)[name = tensor<string, []>("cv_cast_fp16")];
65
+ tensor<int32, [1]> var_89 = const()[name = tensor<string, []>("op_89"), val = tensor<int32, [1]>([114688])];
66
+ tensor<int32, [1, 64, 1]> var_91 = add(x = cv_cast_fp16_1, y = var_89)[name = tensor<string, []>("op_91")];
67
+ tensor<int32, []> var_93 = const()[name = tensor<string, []>("op_93"), val = tensor<int32, []>(-1)];
68
+ tensor<bool, []> values_interleave_0 = const()[name = tensor<string, []>("values_interleave_0"), val = tensor<bool, []>(false)];
69
+ tensor<fp16, [1, 64, 8]> values_cast_fp16 = concat(axis = var_93, interleave = values_interleave_0, values = (var_16_cast_fp16_0, var_26_cast_fp16_0, var_36_cast_fp16_0, var_46_cast_fp16_0, var_56_cast_fp16_0, var_66_cast_fp16_0, var_76_cast_fp16_0, cv_cast_fp16_0))[name = tensor<string, []>("values_cast_fp16")];
70
+ tensor<int32, []> var_96 = const()[name = tensor<string, []>("op_96"), val = tensor<int32, []>(-1)];
71
+ tensor<bool, []> indices_interleave_0 = const()[name = tensor<string, []>("indices_interleave_0"), val = tensor<bool, []>(false)];
72
+ tensor<int32, [1, 64, 8]> indices = concat(axis = var_96, interleave = indices_interleave_0, values = (var_16_cast_fp16_1, var_31, var_41, var_51, var_61, var_71, var_81, var_91))[name = tensor<string, []>("indices")];
73
+ tensor<int32, []> var_98 = const()[name = tensor<string, []>("op_98"), val = tensor<int32, []>(1)];
74
+ tensor<int32, []> var_102_axis_0 = const()[name = tensor<string, []>("op_102_axis_0"), val = tensor<int32, []>(-1)];
75
+ tensor<bool, []> var_102_ascending_0 = const()[name = tensor<string, []>("op_102_ascending_0"), val = tensor<bool, []>(false)];
76
+ tensor<bool, []> var_102_sort_0 = const()[name = tensor<string, []>("op_102_sort_0"), val = tensor<bool, []>(true)];
77
+ tensor<bool, []> var_102_return_indices_0 = const()[name = tensor<string, []>("op_102_return_indices_0"), val = tensor<bool, []>(true)];
78
+ tensor<fp16, [1, 64, 1]> var_102_cast_fp16_0, tensor<int32, [1, 64, 1]> var_102_cast_fp16_1 = topk(ascending = var_102_ascending_0, axis = var_102_axis_0, k = var_98, return_indices = var_102_return_indices_0, sort = var_102_sort_0, x = values_cast_fp16)[name = tensor<string, []>("op_102_cast_fp16")];
79
+ tensor<int32, []> var_104 = const()[name = tensor<string, []>("op_104"), val = tensor<int32, []>(-1)];
80
+ tensor<int32, [1, 64, 1]> var_106 = gather_along_axis(axis = var_104, indices = var_102_cast_fp16_1, x = indices)[name = tensor<string, []>("op_106")];
81
+ tensor<int32, [1]> var_108_axes_0 = const()[name = tensor<string, []>("op_108_axes_0"), val = tensor<int32, [1]>([-1])];
82
+ tensor<int32, [1, 64]> argmax = squeeze(axes = var_108_axes_0, x = var_106)[name = tensor<string, []>("op_108")];
83
  } -> (argmax);
84
  }