program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})] { func main(tensor logits_0, tensor logits_1, tensor logits_2, tensor logits_3, tensor logits_4, tensor logits_5, tensor logits_6, tensor logits_7) { tensor chunk_size = const()[name = tensor("chunk_size"), val = tensor([16384])]; tensor var_12 = const()[name = tensor("op_12"), val = tensor(1)]; tensor var_16_axis_0 = const()[name = tensor("op_16_axis_0"), val = tensor(-1)]; tensor var_16_ascending_0 = const()[name = tensor("op_16_ascending_0"), val = tensor(false)]; tensor var_16_sort_0 = const()[name = tensor("op_16_sort_0"), val = tensor(false)]; tensor var_16_return_indices_0 = const()[name = tensor("op_16_return_indices_0"), val = tensor(true)]; tensor var_16_cast_fp16_0, tensor var_16_cast_fp16_1 = topk(ascending = var_16_ascending_0, axis = var_16_axis_0, k = var_12, return_indices = var_16_return_indices_0, sort = var_16_sort_0, x = logits_0)[name = tensor("op_16_cast_fp16")]; tensor var_22 = const()[name = tensor("op_22"), val = tensor(1)]; tensor var_26_axis_0 = const()[name = tensor("op_26_axis_0"), val = tensor(-1)]; tensor var_26_ascending_0 = const()[name = tensor("op_26_ascending_0"), val = tensor(false)]; tensor var_26_sort_0 = const()[name = tensor("op_26_sort_0"), val = tensor(false)]; tensor var_26_return_indices_0 = const()[name = tensor("op_26_return_indices_0"), val = tensor(true)]; tensor var_26_cast_fp16_0, tensor var_26_cast_fp16_1 = topk(ascending = var_26_ascending_0, axis = var_26_axis_0, k = var_22, return_indices = var_26_return_indices_0, sort = var_26_sort_0, x = logits_1)[name = tensor("op_26_cast_fp16")]; tensor var_31 = add(x = var_26_cast_fp16_1, y = chunk_size)[name = tensor("op_31")]; tensor var_32 = const()[name = tensor("op_32"), val = tensor(1)]; tensor var_36_axis_0 = const()[name = tensor("op_36_axis_0"), val = tensor(-1)]; tensor var_36_ascending_0 = const()[name = tensor("op_36_ascending_0"), val = tensor(false)]; tensor var_36_sort_0 = const()[name = tensor("op_36_sort_0"), val = tensor(false)]; tensor var_36_return_indices_0 = const()[name = tensor("op_36_return_indices_0"), val = tensor(true)]; tensor var_36_cast_fp16_0, tensor var_36_cast_fp16_1 = topk(ascending = var_36_ascending_0, axis = var_36_axis_0, k = var_32, return_indices = var_36_return_indices_0, sort = var_36_sort_0, x = logits_2)[name = tensor("op_36_cast_fp16")]; tensor var_39 = const()[name = tensor("op_39"), val = tensor([32768])]; tensor var_41 = add(x = var_36_cast_fp16_1, y = var_39)[name = tensor("op_41")]; tensor var_42 = const()[name = tensor("op_42"), val = tensor(1)]; tensor var_46_axis_0 = const()[name = tensor("op_46_axis_0"), val = tensor(-1)]; tensor var_46_ascending_0 = const()[name = tensor("op_46_ascending_0"), val = tensor(false)]; tensor var_46_sort_0 = const()[name = tensor("op_46_sort_0"), val = tensor(false)]; tensor var_46_return_indices_0 = const()[name = tensor("op_46_return_indices_0"), val = tensor(true)]; tensor var_46_cast_fp16_0, tensor var_46_cast_fp16_1 = topk(ascending = var_46_ascending_0, axis = var_46_axis_0, k = var_42, return_indices = var_46_return_indices_0, sort = var_46_sort_0, x = logits_3)[name = tensor("op_46_cast_fp16")]; tensor var_49 = const()[name = tensor("op_49"), val = tensor([49152])]; tensor var_51 = add(x = var_46_cast_fp16_1, y = var_49)[name = tensor("op_51")]; tensor var_52 = const()[name = tensor("op_52"), val = tensor(1)]; tensor var_56_axis_0 = const()[name = tensor("op_56_axis_0"), val = tensor(-1)]; tensor var_56_ascending_0 = const()[name = tensor("op_56_ascending_0"), val = tensor(false)]; tensor var_56_sort_0 = const()[name = tensor("op_56_sort_0"), val = tensor(false)]; tensor var_56_return_indices_0 = const()[name = tensor("op_56_return_indices_0"), val = tensor(true)]; tensor var_56_cast_fp16_0, tensor var_56_cast_fp16_1 = topk(ascending = var_56_ascending_0, axis = var_56_axis_0, k = var_52, return_indices = var_56_return_indices_0, sort = var_56_sort_0, x = logits_4)[name = tensor("op_56_cast_fp16")]; tensor var_59 = const()[name = tensor("op_59"), val = tensor([65536])]; tensor var_61 = add(x = var_56_cast_fp16_1, y = var_59)[name = tensor("op_61")]; tensor var_62 = const()[name = tensor("op_62"), val = tensor(1)]; tensor var_66_axis_0 = const()[name = tensor("op_66_axis_0"), val = tensor(-1)]; tensor var_66_ascending_0 = const()[name = tensor("op_66_ascending_0"), val = tensor(false)]; tensor var_66_sort_0 = const()[name = tensor("op_66_sort_0"), val = tensor(false)]; tensor var_66_return_indices_0 = const()[name = tensor("op_66_return_indices_0"), val = tensor(true)]; tensor var_66_cast_fp16_0, tensor var_66_cast_fp16_1 = topk(ascending = var_66_ascending_0, axis = var_66_axis_0, k = var_62, return_indices = var_66_return_indices_0, sort = var_66_sort_0, x = logits_5)[name = tensor("op_66_cast_fp16")]; tensor var_69 = const()[name = tensor("op_69"), val = tensor([81920])]; tensor var_71 = add(x = var_66_cast_fp16_1, y = var_69)[name = tensor("op_71")]; tensor var_72 = const()[name = tensor("op_72"), val = tensor(1)]; tensor var_76_axis_0 = const()[name = tensor("op_76_axis_0"), val = tensor(-1)]; tensor var_76_ascending_0 = const()[name = tensor("op_76_ascending_0"), val = tensor(false)]; tensor var_76_sort_0 = const()[name = tensor("op_76_sort_0"), val = tensor(false)]; tensor var_76_return_indices_0 = const()[name = tensor("op_76_return_indices_0"), val = tensor(true)]; tensor var_76_cast_fp16_0, tensor var_76_cast_fp16_1 = topk(ascending = var_76_ascending_0, axis = var_76_axis_0, k = var_72, return_indices = var_76_return_indices_0, sort = var_76_sort_0, x = logits_6)[name = tensor("op_76_cast_fp16")]; tensor var_79 = const()[name = tensor("op_79"), val = tensor([98304])]; tensor var_81 = add(x = var_76_cast_fp16_1, y = var_79)[name = tensor("op_81")]; tensor var_82 = const()[name = tensor("op_82"), val = tensor(1)]; tensor cv_axis_0 = const()[name = tensor("cv_axis_0"), val = tensor(-1)]; tensor cv_ascending_0 = const()[name = tensor("cv_ascending_0"), val = tensor(false)]; tensor cv_sort_0 = const()[name = tensor("cv_sort_0"), val = tensor(false)]; tensor cv_return_indices_0 = const()[name = tensor("cv_return_indices_0"), val = tensor(true)]; tensor cv_cast_fp16_0, tensor cv_cast_fp16_1 = topk(ascending = cv_ascending_0, axis = cv_axis_0, k = var_82, return_indices = cv_return_indices_0, sort = cv_sort_0, x = logits_7)[name = tensor("cv_cast_fp16")]; tensor var_89 = const()[name = tensor("op_89"), val = tensor([114688])]; tensor var_91 = add(x = cv_cast_fp16_1, y = var_89)[name = tensor("op_91")]; tensor var_93 = const()[name = tensor("op_93"), val = tensor(-1)]; tensor values_interleave_0 = const()[name = tensor("values_interleave_0"), val = tensor(false)]; tensor values_cast_fp16 = concat(axis = var_93, interleave = values_interleave_0, values = (var_16_cast_fp16_0, var_26_cast_fp16_0, var_36_cast_fp16_0, var_46_cast_fp16_0, var_56_cast_fp16_0, var_66_cast_fp16_0, var_76_cast_fp16_0, cv_cast_fp16_0))[name = tensor("values_cast_fp16")]; tensor var_96 = const()[name = tensor("op_96"), val = tensor(-1)]; tensor indices_interleave_0 = const()[name = tensor("indices_interleave_0"), val = tensor(false)]; tensor indices = concat(axis = var_96, interleave = indices_interleave_0, values = (var_16_cast_fp16_1, var_31, var_41, var_51, var_61, var_71, var_81, var_91))[name = tensor("indices")]; tensor var_98 = const()[name = tensor("op_98"), val = tensor(1)]; tensor var_102_axis_0 = const()[name = tensor("op_102_axis_0"), val = tensor(-1)]; tensor var_102_ascending_0 = const()[name = tensor("op_102_ascending_0"), val = tensor(false)]; tensor var_102_sort_0 = const()[name = tensor("op_102_sort_0"), val = tensor(true)]; tensor var_102_return_indices_0 = const()[name = tensor("op_102_return_indices_0"), val = tensor(true)]; tensor var_102_cast_fp16_0, tensor var_102_cast_fp16_1 = topk(ascending = var_102_ascending_0, axis = var_102_axis_0, k = var_98, return_indices = var_102_return_indices_0, sort = var_102_sort_0, x = values_cast_fp16)[name = tensor("op_102_cast_fp16")]; tensor var_104 = const()[name = tensor("op_104"), val = tensor(-1)]; tensor var_106 = gather_along_axis(axis = var_104, indices = var_102_cast_fp16_1, x = indices)[name = tensor("op_106")]; tensor var_108_axes_0 = const()[name = tensor("op_108_axes_0"), val = tensor([-1])]; tensor argmax = squeeze(axes = var_108_axes_0, x = var_106)[name = tensor("op_108")]; } -> (argmax); }