mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix scatter + test (#1202)
* fix scatter + test * fix test warnings * fix metal validation
This commit is contained in:
parent
709ccc6800
commit
df964132fb
@ -293,7 +293,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
|
@ -38,12 +38,24 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
|
@ -10,7 +10,10 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
@ -21,7 +24,14 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
if (upd_ndim > 1) {
|
||||
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
} else {
|
||||
out_idx += gid.x;
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
|
@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80)
|
||||
self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120)
|
||||
|
||||
x = np.fft.rfft(a_np)
|
||||
x = np.fft.rfft(np.real(a_np))
|
||||
self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)
|
||||
|
||||
def test_fftn(self):
|
||||
@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
if op in ["rfft2", "rfftn"]:
|
||||
x = r
|
||||
elif op == "irfft2":
|
||||
x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s))
|
||||
x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s))
|
||||
elif op == "irfftn":
|
||||
x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s))
|
||||
x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s))
|
||||
mx_op = getattr(mx.fft, op)
|
||||
np_op = getattr(np.fft, op)
|
||||
self.check_mx_np(mx_op, np_op, x, axes=ax, s=s)
|
||||
@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
|
||||
self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol)
|
||||
|
||||
ia_np = np.fft.rfft(a_np)
|
||||
ia_np = np.fft.rfft(r)
|
||||
self.check_mx_np(
|
||||
mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1]
|
||||
)
|
||||
|
@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") {
|
||||
auto out = scatter(dst, idx, src, 0);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
// 1D indices with 2D update
|
||||
{
|
||||
auto dst = zeros({3, 4}, int32);
|
||||
auto indices = {array({1}), array({2})};
|
||||
auto axes = {0, 1};
|
||||
auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
|
||||
auto out = scatter(dst, indices, updates, axes);
|
||||
auto expected =
|
||||
reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test is positive infinity") {
|
||||
|
Loading…
Reference in New Issue
Block a user