fix scatter + test (#1202)

* fix scatter + test

* fix test warnings

* fix metal validation
This commit is contained in:
Awni Hannun 2024-06-11 14:35:12 -07:00 committed by GitHub
parent 709ccc6800
commit df964132fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 9 deletions

View File

@ -293,7 +293,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
out.shape().data(), out.shape().size() * sizeof(int), 3); out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes( compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4); out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't compalain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers // Set index buffers
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {

View File

@ -38,12 +38,24 @@ constexpr std::string_view scatter_kernels = R"(
device mlx_atomic<{1}>* out [[buffer(2)]], device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]], const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]], const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]], const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5} {5}
uint2 gid [[thread_position_in_grid]]) {{ uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }}; const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>( return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}} }}
[[kernel]] void scatter{0}_{4}( [[kernel]] void scatter{0}_{4}(

View File

@ -10,7 +10,10 @@ METAL_FUNC void scatter_1d_index_impl(
device mlx_atomic<T>* out [[buffer(2)]], device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]], const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]], const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]], const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
const thread array<const device IdxT*, NIDX>& idx_buffers, const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) { uint2 gid [[thread_position_in_grid]]) {
Op op; Op op;
@ -21,7 +24,14 @@ METAL_FUNC void scatter_1d_index_impl(
out_idx += idx_val * out_strides[i]; out_idx += idx_val * out_strides[i];
} }
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); if (upd_ndim > 1) {
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
out_idx += out_offset;
} else {
out_idx += gid.x;
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
} }
template <typename T, typename IdxT, typename Op, int NIDX> template <typename T, typename IdxT, typename Op, int NIDX>

View File

@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase):
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
x = np.fft.rfft(a_np) x = np.fft.rfft(np.real(a_np))
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x) self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
def test_fftn(self): def test_fftn(self):
@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase):
if op in ["rfft2", "rfftn"]: if op in ["rfft2", "rfftn"]:
x = r x = r
elif op == "irfft2": elif op == "irfft2":
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s)) x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))
elif op == "irfftn": elif op == "irfftn":
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s)) x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))
mx_op = getattr(mx.fft, op) mx_op = getattr(mx.fft, op)
np_op = getattr(np.fft, op) np_op = getattr(np.fft, op)
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s) self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase):
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol) self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
ia_np = np.fft.rfft(a_np) ia_np = np.fft.rfft(r)
self.check_mx_np( self.check_mx_np(
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1] mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
) )

View File

@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") {
auto out = scatter(dst, idx, src, 0); auto out = scatter(dst, idx, src, 0);
CHECK(array_equal(out, expected).item<bool>()); CHECK(array_equal(out, expected).item<bool>());
} }
// 1D indices with 2D update
{
auto dst = zeros({3, 4}, int32);
auto indices = {array({1}), array({2})};
auto axes = {0, 1};
auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
auto out = scatter(dst, indices, updates, axes);
auto expected =
reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
CHECK(array_equal(out, expected).item<bool>());
}
} }
TEST_CASE("test is positive infinity") { TEST_CASE("test is positive infinity") {