diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index a956a6a03..9dabe1f3f 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -60,7 +60,7 @@ void binary_op( break; } kname << op << type_to_name(a); - if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); } @@ -158,7 +158,7 @@ void binary_op( break; } kname << op << type_to_name(a); - if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); }