fix scatter (#821)

This commit is contained in:
Awni Hannun 2024-03-12 11:42:07 -07:00 committed by GitHub
parent 366478c560
commit 8b7532b9ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 13 deletions

View File

@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {

View File

@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
out_idx += idx_val * out_strides[i];
}
if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
out_shape, \
out_strides, \
upd_size, \
upd_col_contiguous, \
idx_buffers, \
gid); \
\
@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cmath>
#include <numeric>
@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") {
inds = array({0, 1});
out = scatter_add(in, inds, updates, 0);
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
// 1D scatter
{
auto dst = zeros({2, 4}, int32);
auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
auto idx = array({1});
auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
auto out = scatter(dst, idx, src, 0);
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test is positive infinity") {