program(1.3) [buildInfo = dict({{"coremlc-component-MIL", "3400.42.1"}, {"coremlc-version", "3400.51.1"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})] { func main(tensor logits) [FlexibleShapeInformation = tuple>>, tuple>>>>((("DefaultShapes", {{"logits", [1, 511, 32000]}}), ("EnumeratedShapes", {{"logits_1_1_1_1_32000_", {{"logits", [1, 1, 32000]}}}, {"logits_1_1_1_2_32000_", {{"logits", [1, 2, 32000]}}}, {"logits_1_1_1_4_32000_", {{"logits", [1, 4, 32000]}}}, {"logits_1_1_1_511_32000_", {{"logits", [1, 511, 32000]}}}, {"logits_1_1_1_512_32000_", {{"logits", [1, 512, 32000]}}}, {"logits_1_1_1_64_32000_", {{"logits", [1, 64, 32000]}}}})))] { int32 var_2 = const()[name = string("op_2"), val = int32(-1)]; bool var_3 = const()[name = string("op_3"), val = bool(false)]; tensor argmax = reduce_argmax(axis = var_2, keep_dims = var_3, x = logits)[name = string("op_4_cast_fp16")]; } -> (argmax); }