program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})] { func main(tensor logits) [FlexibleShapeInformation = tuple, dict, tensor>>, tuple, dict, dict, tensor>>>>((("DefaultShapes", {{"logits", [1, 511, 32000]}}), ("EnumeratedShapes", {{"logits_1_1_1_1_32000_", {{"logits", [1, 1, 32000]}}}, {"logits_1_1_1_2_32000_", {{"logits", [1, 2, 32000]}}}, {"logits_1_1_1_4_32000_", {{"logits", [1, 4, 32000]}}}, {"logits_1_1_1_511_32000_", {{"logits", [1, 511, 32000]}}}, {"logits_1_1_1_512_32000_", {{"logits", [1, 512, 32000]}}}, {"logits_1_1_1_64_32000_", {{"logits", [1, 64, 32000]}}}})))] { tensor var_2 = const()[name = tensor("op_2"), val = tensor(-1)]; tensor var_3 = const()[name = tensor("op_3"), val = tensor(false)]; tensor argmax = reduce_argmax(axis = var_2, keep_dims = var_3, x = logits)[name = tensor("op_4_cast_fp16")]; } -> (argmax); }