From 01a29b51c802ac0b2d153e410aab724a0c9ecaaa Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 May 2025 20:25:28 -0700 Subject: [PATCH] fix --- mlx/backend/metal/binary.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 5943c5b27..c3c67e4d5 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -145,10 +145,10 @@ void binary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { - compute_encoder.set_bytes(a.size(), arg_idx++); + compute_encoder.set_bytes(out.data_size(), arg_idx++); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { - compute_encoder.set_bytes(a.size(), arg_idx++); + compute_encoder.set_bytes(out.data_size(), arg_idx++); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims);