mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix scatter_axis on single axis arrays
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
@@ -575,9 +576,17 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.set_output_array(out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Set source info
|
// Set source info
|
||||||
|
if (ndim > 1) {
|
||||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||||
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||||
|
} else {
|
||||||
|
// The following will be ignored in the kernel but we still have to set
|
||||||
|
// some value so that metal validation passes.
|
||||||
|
compute_encoder.set_vector_bytes(idx.shape(), 3);
|
||||||
|
compute_encoder.set_vector_bytes(upd.strides(), 4);
|
||||||
|
compute_encoder.set_vector_bytes(idx.strides(), 5);
|
||||||
|
}
|
||||||
compute_encoder.set_bytes(ndim - 1, 6);
|
compute_encoder.set_bytes(ndim - 1, 6);
|
||||||
compute_encoder.set_bytes(axis_, 7);
|
compute_encoder.set_bytes(axis_, 7);
|
||||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||||
|
|||||||
Reference in New Issue
Block a user