fix scatter (#821)

This commit is contained in:
Awni Hannun 2024-03-12 11:42:07 -07:00 committed by GitHub
parent 366478c560
commit 8b7532b9ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 13 deletions

View File

@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
for (int i = idx_ndim; i < upd.ndim(); ++i) { for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i); upd_size *= upd.shape(i);
} }
if (index_nd1_specialization) { if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes( compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3); out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes( compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4); out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers // Set index buffers
for (int i = 1; i < nidx + 1; ++i) { for (int i = 1; i < nidx + 1; ++i) {

View File

@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
const constant int* out_shape [[buffer(3)]], const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]], const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]], const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers, const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) { uint2 gid [[thread_position_in_grid]]) {
@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
out_idx += idx_val * out_strides[i]; out_idx += idx_val * out_strides[i];
} }
if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
} }
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ #define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
const constant int* out_shape [[buffer(3)]], \ const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \ const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \ const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \ IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \ uint2 gid [[thread_position_in_grid]]) { \
\ \
@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
out_shape, \ out_shape, \
out_strides, \ out_strides, \
upd_size, \ upd_size, \
upd_col_contiguous, \
idx_buffers, \ idx_buffers, \
gid); \ gid); \
\ \
@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
const constant int* out_shape [[buffer(3)]], \ const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \ const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \ const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \ IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]); uint2 gid [[thread_position_in_grid]]);

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") {
inds = array({0, 1}); inds = array({0, 1});
out = scatter_add(in, inds, updates, 0); out = scatter_add(in, inds, updates, 0);
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>()); CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
// 1D scatter
{
auto dst = zeros({2, 4}, int32);
auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
auto idx = array({1});
auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
auto out = scatter(dst, idx, src, 0);
CHECK(array_equal(out, expected).item<bool>());
}
} }
TEST_CASE("test is positive infinity") { TEST_CASE("test is positive infinity") {