diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 5943c5b27..c3c67e4d5 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -145,10 +145,10 @@ void binary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { - compute_encoder.set_bytes(a.size(), arg_idx++); + compute_encoder.set_bytes(out.data_size(), arg_idx++); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { - compute_encoder.set_bytes(a.size(), arg_idx++); + compute_encoder.set_bytes(out.data_size(), arg_idx++); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims);