mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix scatter (#821)
This commit is contained in:
parent
366478c560
commit
8b7532b9ab
@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
|
||||
if (index_nd1_specialization) {
|
||||
bool upd_col_contiguous = upd.flags().col_contiguous;
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
|
@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (!upd_col_contiguous) {
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
} else {
|
||||
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
|
||||
}
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
}
|
||||
|
||||
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
|
||||
@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
upd_size, \
|
||||
upd_col_contiguous, \
|
||||
idx_buffers, \
|
||||
gid); \
|
||||
\
|
||||
@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") {
|
||||
inds = array({0, 1});
|
||||
out = scatter_add(in, inds, updates, 0);
|
||||
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
||||
|
||||
// 1D scatter
|
||||
{
|
||||
auto dst = zeros({2, 4}, int32);
|
||||
auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
|
||||
auto idx = array({1});
|
||||
auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
|
||||
auto out = scatter(dst, idx, src, 0);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test is positive infinity") {
|
||||
|
Loading…
Reference in New Issue
Block a user