From 65d0b8df9fb02a30445fc5431aa2d8e73ca0af2f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 29 Jan 2024 19:36:17 -0800 Subject: [PATCH] Fix binary op dispatch (#584) --- mlx/backend/metal/primitives.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index a956a6a03..9dabe1f3f 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -60,7 +60,7 @@ void binary_op( break; } kname << op << type_to_name(a); - if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); } @@ -158,7 +158,7 @@ void binary_op( break; } kname << op << type_to_name(a); - if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); }