mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix scatter + test (#1202)
* fix scatter + test * fix test warnings * fix metal validation
This commit is contained in:
parent
709ccc6800
commit
df964132fb
@ -293,7 +293,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||||
compute_encoder->setBytes(
|
compute_encoder->setBytes(
|
||||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
|
||||||
|
size_t out_ndim = out.ndim();
|
||||||
|
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||||
|
if (upd_ndim <= 1) {
|
||||||
|
// Placeholder so Metal doesn't complain
|
||||||
|
int shape_ = 0;
|
||||||
|
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||||
|
} else {
|
||||||
|
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||||
|
}
|
||||||
|
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||||
|
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||||
|
|
||||||
// Set index buffers
|
// Set index buffers
|
||||||
for (int i = 0; i < nidx; ++i) {
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
@ -38,12 +38,24 @@ constexpr std::string_view scatter_kernels = R"(
|
|||||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||||
const constant int* out_shape [[buffer(3)]],
|
const constant int* out_shape [[buffer(3)]],
|
||||||
const constant size_t* out_strides [[buffer(4)]],
|
const constant size_t* out_strides [[buffer(4)]],
|
||||||
const constant size_t& upd_size [[buffer(5)]],
|
const constant size_t& out_ndim [[buffer(5)]],
|
||||||
|
const constant int* upd_shape [[buffer(6)]],
|
||||||
|
const constant size_t& upd_ndim [[buffer(7)]],
|
||||||
|
const constant size_t& upd_size [[buffer(8)]],
|
||||||
{5}
|
{5}
|
||||||
uint2 gid [[thread_position_in_grid]]) {{
|
uint2 gid [[thread_position_in_grid]]) {{
|
||||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
updates,
|
||||||
|
out,
|
||||||
|
out_shape,
|
||||||
|
out_strides,
|
||||||
|
out_ndim,
|
||||||
|
upd_shape,
|
||||||
|
upd_ndim,
|
||||||
|
upd_size,
|
||||||
|
idx_buffers,
|
||||||
|
gid);
|
||||||
}}
|
}}
|
||||||
|
|
||||||
[[kernel]] void scatter{0}_{4}(
|
[[kernel]] void scatter{0}_{4}(
|
||||||
|
@ -10,7 +10,10 @@ METAL_FUNC void scatter_1d_index_impl(
|
|||||||
device mlx_atomic<T>* out [[buffer(2)]],
|
device mlx_atomic<T>* out [[buffer(2)]],
|
||||||
const constant int* out_shape [[buffer(3)]],
|
const constant int* out_shape [[buffer(3)]],
|
||||||
const constant size_t* out_strides [[buffer(4)]],
|
const constant size_t* out_strides [[buffer(4)]],
|
||||||
const constant size_t& upd_size [[buffer(5)]],
|
const constant size_t& out_ndim [[buffer(5)]],
|
||||||
|
const constant int* upd_shape [[buffer(6)]],
|
||||||
|
const constant size_t& upd_ndim [[buffer(7)]],
|
||||||
|
const constant size_t& upd_size [[buffer(8)]],
|
||||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
Op op;
|
Op op;
|
||||||
@ -21,7 +24,14 @@ METAL_FUNC void scatter_1d_index_impl(
|
|||||||
out_idx += idx_val * out_strides[i];
|
out_idx += idx_val * out_strides[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
if (upd_ndim > 1) {
|
||||||
|
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||||
|
out_idx += out_offset;
|
||||||
|
} else {
|
||||||
|
out_idx += gid.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||||
|
@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||||
|
|
||||||
x = np.fft.rfft(a_np)
|
x = np.fft.rfft(np.real(a_np))
|
||||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
||||||
|
|
||||||
def test_fftn(self):
|
def test_fftn(self):
|
||||||
@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
if op in ["rfft2", "rfftn"]:
|
if op in ["rfft2", "rfftn"]:
|
||||||
x = r
|
x = r
|
||||||
elif op == "irfft2":
|
elif op == "irfft2":
|
||||||
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
|
x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))
|
||||||
elif op == "irfftn":
|
elif op == "irfftn":
|
||||||
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
|
x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))
|
||||||
mx_op = getattr(mx.fft, op)
|
mx_op = getattr(mx.fft, op)
|
||||||
np_op = getattr(np.fft, op)
|
np_op = getattr(np.fft, op)
|
||||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||||
@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
ia_np = np.fft.rfft(a_np)
|
ia_np = np.fft.rfft(r)
|
||||||
self.check_mx_np(
|
self.check_mx_np(
|
||||||
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
||||||
)
|
)
|
||||||
|
@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") {
|
|||||||
auto out = scatter(dst, idx, src, 0);
|
auto out = scatter(dst, idx, src, 0);
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 1D indices with 2D update
|
||||||
|
{
|
||||||
|
auto dst = zeros({3, 4}, int32);
|
||||||
|
auto indices = {array({1}), array({2})};
|
||||||
|
auto axes = {0, 1};
|
||||||
|
auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
|
||||||
|
auto out = scatter(dst, indices, updates, axes);
|
||||||
|
auto expected =
|
||||||
|
reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test is positive infinity") {
|
TEST_CASE("test is positive infinity") {
|
||||||
|
Loading…
Reference in New Issue
Block a user