mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
parent
0fe6895893
commit
496315fe1d
@ -234,7 +234,7 @@ void scan_dispatch(
|
|||||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||||
auto init = (issubdtype(input.dtype(), floating))
|
auto init = (issubdtype(input.dtype(), floating))
|
||||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||||
: std::numeric_limits<U>::max();
|
: std::numeric_limits<U>::min();
|
||||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||||
|
@ -309,6 +309,7 @@ template <
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Share the prefix
|
// Share the prefix
|
||||||
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
||||||
|
@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Compute the thread grid
|
// Compute the thread grid
|
||||||
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||||
int elements_per_simd = n_reads * 32;
|
constexpr int simd_size = 32;
|
||||||
|
int elements_per_simd = n_reads * simd_size;
|
||||||
int thread_groups = in.size() / size;
|
int thread_groups = in.size() / size;
|
||||||
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (size < n_reads * 1024) {
|
if (size <= n_reads * 1024) {
|
||||||
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
|
|
||||||
elements_per_simd;
|
|
||||||
} else if (size < n_reads * 2048) {
|
|
||||||
thread_group_size =
|
thread_group_size =
|
||||||
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
|
((size + elements_per_simd - 1) / elements_per_simd) * simd_size;
|
||||||
elements_per_simd;
|
} else if (size <= n_reads * 2048) {
|
||||||
|
thread_group_size =
|
||||||
|
((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;
|
||||||
}
|
}
|
||||||
thread_group_size = std::min(
|
thread_group_size = std::min(
|
||||||
thread_group_size,
|
thread_group_size,
|
||||||
|
@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
c_mlx = mxop(a_mlx, axis=axis)
|
c_mlx = mxop(a_mlx, axis=axis)
|
||||||
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
|
||||||
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
|
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
|
||||||
|
mxop = getattr(mx, op)
|
||||||
c1 = mxop(a_mlx, axis=2)
|
c1 = mxop(a_mlx, axis=2)
|
||||||
c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)
|
c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)
|
||||||
self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:]))
|
self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:]))
|
||||||
@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = mx.cumsum(a_t, axis=-1)
|
out = mx.cumsum(a_t, axis=-1)
|
||||||
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
|
||||||
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
|
||||||
|
sizes = [1023, 1024, 1025, 2047, 2048, 2049]
|
||||||
|
for s in sizes:
|
||||||
|
a = mx.ones((s,), mx.int32)
|
||||||
|
out = mx.cumsum(a)
|
||||||
|
expected = mx.arange(1, s + 1, dtype=mx.int32)
|
||||||
|
self.assertTrue(mx.array_equal(expected, out))
|
||||||
|
|
||||||
|
# non-contiguous scan
|
||||||
|
a = mx.ones((s, 2), mx.int32)
|
||||||
|
out = mx.cumsum(a, axis=0)
|
||||||
|
expected = mx.repeat(expected[:, None], 2, axis=1)
|
||||||
|
self.assertTrue(mx.array_equal(expected, out))
|
||||||
|
|
||||||
def test_squeeze_expand(self):
|
def test_squeeze_expand(self):
|
||||||
a = mx.zeros((2, 1, 2, 1))
|
a = mx.zeros((2, 1, 2, 1))
|
||||||
|
Loading…
Reference in New Issue
Block a user