smpanaro's picture
Update sequoia mode with transposed value cache and 4:508 input:cache length
722eedf verified
program(1.3)
[buildInfo = dict<string, string>({{"coremlc-component-MIL", "3400.42.1"}, {"coremlc-version", "3400.51.1"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
{
    func main<ios16>(tensor<fp16, [1, ?, 32000]> logits) [FlexibleShapeInformation = tuple<tuple<string, dict<string, tensor<int32, [?]>>>, tuple<string, dict<string, dict<string, tensor<int32, [?]>>>>>((("DefaultShapes", {{"logits", [1, 511, 32000]}}), ("EnumeratedShapes", {{"logits_1_1_1_1_32000_", {{"logits", [1, 1, 32000]}}}, {"logits_1_1_1_2_32000_", {{"logits", [1, 2, 32000]}}}, {"logits_1_1_1_4_32000_", {{"logits", [1, 4, 32000]}}}, {"logits_1_1_1_511_32000_", {{"logits", [1, 511, 32000]}}}, {"logits_1_1_1_512_32000_", {{"logits", [1, 512, 32000]}}}, {"logits_1_1_1_64_32000_", {{"logits", [1, 64, 32000]}}}})))] {
        bool argmax_keep_dims = const()[name = string("op_3"), val = bool(false)];
        int32 argmax_axis = const()[name = string("op_2"), val = int32(-1)];
        tensor<int32, [1, ?]> argmax = reduce_argmax(axis = argmax_axis, keep_dims = argmax_keep_dims, x = logits)[name = string("op_4_cast_fp16")];
    } -> (argmax);
}