File size: 3,453 Bytes
a76a14d
dba673f
a76a14d
dba673f
 
a76a14d
dba673f
a76a14d
 
dba673f
 
 
 
 
a76a14d
dba673f
a76a14d
dba673f
a76a14d
dba673f
a76a14d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
program(1.0)
[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
{
    func main<ios16>(tensor<fp16, [1, 64, 1, 4096]> new_k_cache, tensor<fp16, [1, 4096, 1, 64]> new_v_cache, tensor<fp16, [1, 448, 1, 4096]> old_k_cache, tensor<fp16, [1, 4096, 1, 448]> old_v_cache) {
            tensor<int32, []> k_concat_axis = const()[name = tensor<string, []>("op_6"), val = tensor<int32, []>(-3)];
            tensor<bool, []> k_concat_interleave = const()[name = tensor<string, []>("cat_k_1_interleave_0"), val = tensor<bool, []>(false)];
            tensor<fp16, [1, 512, 1, 4096]> k_cache_joined = concat(axis = k_concat_axis, interleave = k_concat_interleave, values = (old_k_cache, new_k_cache))[name = tensor<string, []>("cat_k_1_cast_fp16")];
            tensor<int32, [4]> k_window_begin = const()[name = tensor<string, []>("op_20_begin_0"), val = tensor<int32, [4]>([0, 1, 0, 0])];
            tensor<int32, [4]> k_window_end = const()[name = tensor<string, []>("op_20_end_0"), val = tensor<int32, [4]>([1, 449, 1, 4096])];
            tensor<bool, [4]> k_window_end_mask = const()[name = tensor<string, []>("op_20_end_mask_0"), val = tensor<bool, [4]>([true, false, true, true])];
            tensor<fp16, [1, 448, 1, 4096]> generation_k_cache = slice_by_index(begin = k_window_begin, end = k_window_end, end_mask = k_window_end_mask, x = k_cache_joined)[name = tensor<string, []>("op_20_cast_fp16")];
            tensor<int32, []> v_concat_axis = const()[name = tensor<string, []>("op_9"), val = tensor<int32, []>(-1)];
            tensor<bool, []> v_concat_interleave = const()[name = tensor<string, []>("cat_v_interleave_0"), val = tensor<bool, []>(false)];
            tensor<fp16, [1, 4096, 1, 512]> v_cache_joined = concat(axis = v_concat_axis, interleave = v_concat_interleave, values = (old_v_cache, new_v_cache))[name = tensor<string, []>("cat_v_cast_fp16")];
            tensor<int32, [4]> v_window_begin = const()[name = tensor<string, []>("op_50_begin_0"), val = tensor<int32, [4]>([0, 0, 0, 1])];
            tensor<int32, [4]> v_window_end = const()[name = tensor<string, []>("op_50_end_0"), val = tensor<int32, [4]>([1, 4096, 1, 449])];
            tensor<bool, [4]> v_window_end_mask = const()[name = tensor<string, []>("op_50_end_mask_0"), val = tensor<bool, [4]>([true, true, true, false])];
            tensor<fp16, [1, 4096, 1, 448]> generation_v_cache = slice_by_index(begin = v_window_begin, end = v_window_end, end_mask = v_window_end_mask, x = v_cache_joined)[name = tensor<string, []>("op_50_cast_fp16")];
            tensor<fp16, []> dummy_scale_fp16 = const()[name = tensor<string, []>("op_51_promoted_to_fp16"), val = tensor<fp16, []>(0x1p+1)];
            tensor<fp16, [1, 448, 1, 4096]> dummy_scaled_k = mul(x = generation_k_cache, y = dummy_scale_fp16)[name = tensor<string, []>("prod_cast_fp16")];
            tensor<bool, []> dummy_reduce_keep_dims = const()[name = tensor<string, []>("op_53_keep_dims_0"), val = tensor<bool, []>(false)];
            tensor<fp16, []> ignore_me_im_only_here_so_this_runs_on_the_ane = reduce_min(keep_dims = dummy_reduce_keep_dims, x = dummy_scaled_k)[name = tensor<string, []>("op_53_cast_fp16")];
        } -> (generation_k_cache, generation_v_cache, ignore_me_im_only_here_so_this_runs_on_the_ane);
}