From f2994d5b2955f5b9893f16e2e5eaa1b7ec9e7673 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 3 Jul 2025 14:50:23 -0700 Subject: [PATCH] Fix scatter_axis on single axis arrays --- mlx/backend/metal/indexing.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a601b1e..b01eeec7e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" @@ -575,9 +576,17 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); - compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); - compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + if (ndim > 1) { + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + } else { + // The following will be ignored in the kernel but we still have to set + // some value so that metal validation passes. + compute_encoder.set_vector_bytes(idx.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); + compute_encoder.set_vector_bytes(idx.strides(), 5); + } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8);