program(1.0)
[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
{
func main<ios16>(tensor<fp16, [1, 4096, 8, 8]> x) {
tensor<bool, []> var_6 = const()[name = tensor<string, []>("op_6"), val = tensor<bool, []>(true)];
tensor<int32, []> var_9 = const()[name = tensor<string, []>("op_9"), val = tensor<int32, []>(1)];
tensor<bool, []> x_eps_interleave_0 = const()[name = tensor<string, []>("x_eps_interleave_0"), val = tensor<bool, []>(false)];
tensor<fp16, [1, 1, 8, 8]> eps_chan_to_fp16 = const()[name = tensor<string, []>("eps_chan_to_fp16"), val = tensor<fp16, [1, 1, 8, 8]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(64)))];
tensor<fp16, [1, 4097, 8, 8]> x_eps_cast_fp16 = concat(axis = var_9, interleave = x_eps_interleave_0, values = (x, eps_chan_to_fp16))[name = tensor<string, []>("x_eps_cast_fp16")];
tensor<int32, [1]> norm_x_axes_0 = const()[name = tensor<string, []>("norm_x_axes_0"), val = tensor<int32, [1]>([1])];
tensor<fp16, [1, 1, 8, 8]> norm_x_cast_fp16 = reduce_l2_norm(axes = norm_x_axes_0, keep_dims = var_6, x = x_eps_cast_fp16)[name = tensor<string, []>("norm_x_cast_fp16")];
tensor<fp16, [1, 4096, 8, 8]> x_normed_1_cast_fp16 = real_div(x = x, y = norm_x_cast_fp16)[name = tensor<string, []>("x_normed_1_cast_fp16")];
tensor<fp16, []> var_34_to_fp16 = const()[name = tensor<string, []>("op_34_to_fp16"), val = tensor<fp16, []>(0x1p+6)];
tensor<fp16, [1, 4096, 8, 8]> x_normed_3_cast_fp16 = mul(x = x_normed_1_cast_fp16, y = var_34_to_fp16)[name = tensor<string, []>("x_normed_3_cast_fp16")];
tensor<fp16, [1, 4096, 1, 1]> ln_f_weight_to_fp16 = const()[name = tensor<string, []>("ln_f_weight_to_fp16"), val = tensor<fp16, [1, 4096, 1, 1]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(256)))];
tensor<fp16, [1, 4096, 8, 8]> x_5_cast_fp16 = mul(x = x_normed_3_cast_fp16, y = ln_f_weight_to_fp16)[name = tensor<string, []>("x_5_cast_fp16")];
tensor<int32, [4]> var_48 = const()[name = tensor<string, []>("op_48"), val = tensor<int32, [4]>([1, 4096, 1, -1])];
tensor<fp16, [1, 4096, 1, 64]> x_cast_fp16 = reshape(shape = var_48, x = x_5_cast_fp16)[name = tensor<string, []>("x_cast_fp16")];
tensor<int32, [1]> var_51_axes_0 = const()[name = tensor<string, []>("op_51_axes_0"), val = tensor<int32, [1]>([2])];
tensor<fp16, [1, 4096, 64]> var_51_cast_fp16 = squeeze(axes = var_51_axes_0, x = x_cast_fp16)[name = tensor<string, []>("op_51_cast_fp16")];
tensor<int32, [3]> var_54_perm_0 = const()[name = tensor<string, []>("op_54_perm_0"), val = tensor<int32, [3]>([0, 2, 1])];
tensor<int32, [2]> concat_4 = const()[name = tensor<string, []>("concat_4"), val = tensor<int32, [2]>([64, 4096])];
tensor<fp16, [1, 64, 4096]> var_54_cast_fp16 = transpose(perm = var_54_perm_0, x = var_51_cast_fp16)[name = tensor<string, []>("transpose_4")];
tensor<fp16, [64, 4096]> reshape_0_cast_fp16 = reshape(shape = concat_4, x = var_54_cast_fp16)[name = tensor<string, []>("reshape_0_cast_fp16")];
tensor<bool, []> matmul_0_transpose_x_0 = const()[name = tensor<string, []>("matmul_0_transpose_x_0"), val = tensor<bool, []>(false)];
tensor<bool, []> matmul_0_transpose_y_0 = const()[name = tensor<string, []>("matmul_0_transpose_y_0"), val = tensor<bool, []>(false)];
tensor<fp16, [4096, 16384]> transpose_1_to_fp16 = const()[name = tensor<string, []>("transpose_1_to_fp16"), val = tensor<fp16, [4096, 16384]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(8512)))];
tensor<fp16, [64, 16384]> matmul_0_cast_fp16 = matmul(transpose_x = matmul_0_transpose_x_0, transpose_y = matmul_0_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_1_to_fp16)[name = tensor<string, []>("matmul_0_cast_fp16")];
tensor<int32, [3]> concat_8 = const()[name = tensor<string, []>("concat_8"), val = tensor<int32, [3]>([1, 64, 16384])];
tensor<fp16, [1, 64, 16384]> reshape_2_cast_fp16 = reshape(shape = concat_8, x = matmul_0_cast_fp16)[name = tensor<string, []>("reshape_2_cast_fp16")];
tensor<bool, []> matmul_1_transpose_x_0 = const()[name = tensor<string, []>("matmul_1_transpose_x_0"), val = tensor<bool, []>(false)];
tensor<bool, []> matmul_1_transpose_y_0 = const()[name = tensor<string, []>("matmul_1_transpose_y_0"), val = tensor<bool, []>(false)];
tensor<fp16, [4096, 15616]> transpose_3_to_fp16 = const()[name = tensor<string, []>("transpose_3_to_fp16"), val = tensor<fp16, [4096, 15616]>(BLOBFILE(path = tensor<string, []>("@model_path/weights/weight.bin"), offset = tensor<uint64, []>(134226304)))];
tensor<fp16, [64, 15616]> matmul_1_cast_fp16 = matmul(transpose_x = matmul_1_transpose_x_0, transpose_y = matmul_1_transpose_y_0, x = reshape_0_cast_fp16, y = transpose_3_to_fp16)[name = tensor<string, []>("matmul_1_cast_fp16")];
tensor<int32, [3]> concat_16 = const()[name = tensor<string, []>("concat_16"), val = tensor<int32, [3]>([1, 64, 15616])];
tensor<fp16, [1, 64, 15616]> reshape_5_cast_fp16 = reshape(shape = concat_16, x = matmul_1_cast_fp16)[name = tensor<string, []>("reshape_5_cast_fp16")];
tensor<int32, []> var_69 = const()[name = tensor<string, []>("op_69"), val = tensor<int32, []>(-1)];
tensor<bool, []> var_70_interleave_0 = const()[name = tensor<string, []>("op_70_interleave_0"), val = tensor<bool, []>(false)];
tensor<fp16, [1, 64, 32000]> logits = concat(axis = var_69, interleave = var_70_interleave_0, values = (reshape_2_cast_fp16, reshape_5_cast_fp16))[name = tensor<string, []>("op_70_cast_fp16")];
} -> (logits);
}
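Below is a minimal PyTorch sketch of what `main` above appears to compute, for reference only: an RMSNorm-style normalization over the 4096-channel axis (with the epsilon folded in as an extra concatenated channel so a single reduce_l2_norm covers it), followed by an lm_head split into two matmuls of width 16384 and 15616 whose outputs are concatenated into the [1, 64, 32000] logits. The shapes, the 64.0 scale (sqrt of 4096), and the split sizes are read off the tensors above; the module name, the eps value, and the parameter layout are assumptions for illustration, not the exported weights.

import torch
import torch.nn as nn


class NormAndSplitLMHead(nn.Module):
    """Sketch of the normalization + split lm_head seen in the MIL program."""

    def __init__(self, dim=4096, vocab_size=32000, split=16384, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps  # assumed value; the real eps lives in weight.bin
        self.ln_f_weight = nn.Parameter(torch.ones(1, dim, 1, 1))
        # lm_head stored as [dim, vocab] and split along the vocab axis so each
        # chunk matches the 16384 / 15616 widths of the two matmuls above.
        self.head_a = nn.Parameter(torch.empty(dim, split))
        self.head_b = nn.Parameter(torch.empty(dim, vocab_size - split))

    def forward(self, x):  # x: [1, 4096, 8, 8]
        b, c, h, w = x.shape
        # Concatenate a constant epsilon channel so the L2 norm over channels
        # becomes sqrt(sum(x^2) + eps^2), mirroring concat + reduce_l2_norm.
        eps_chan = torch.full((b, 1, h, w), self.eps, dtype=x.dtype, device=x.device)
        norm = torch.linalg.vector_norm(torch.cat([x, eps_chan], dim=1), dim=1, keepdim=True)
        # x / norm * sqrt(dim) == x / sqrt(mean(x^2) + eps^2/dim), i.e. RMSNorm;
        # the 0x1p+6 (= 64.0) constant above is sqrt(4096).
        x = x / norm * (self.dim ** 0.5) * self.ln_f_weight
        # [1, 4096, 8, 8] -> [64, 4096]: one row per 8x8 spatial position.
        tokens = x.reshape(b, c, h * w).transpose(1, 2).reshape(b * h * w, c)
        # Two chunked projections whose outputs concatenate to the full vocab.
        logits = torch.cat([tokens @ self.head_a, tokens @ self.head_b], dim=-1)
        return logits.reshape(b, h * w, -1)  # [1, 64, 32000]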