diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a601b1e..b01eeec7e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" @@ -575,9 +576,17 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); - compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); - compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + if (ndim > 1) { + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + } else { + // The following will be ignored in the kernel but we still have to set + // some value so that metal validation passes. + compute_encoder.set_vector_bytes(idx.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); + compute_encoder.set_vector_bytes(idx.strides(), 5); + } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8);