mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
from itertools import permutations, product
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@@ -1751,60 +1751,93 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
|
||||
|
||||
def test_sort(self):
|
||||
shape = (3, 4, 5)
|
||||
for dtype in ("int32", "float32"):
|
||||
for axis in (None, 0, 1, 2):
|
||||
with self.subTest(dtype=dtype, axis=axis):
|
||||
np.random.seed(0)
|
||||
np_dtype = getattr(np, dtype)
|
||||
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
|
||||
a_mx = mx.array(a_np)
|
||||
shape = (6, 4, 10)
|
||||
tests = product(
|
||||
("int32", "float32"), # type
|
||||
(None, 0, 1, 2), # axis
|
||||
(True, False), # strided
|
||||
)
|
||||
for dtype, axis, strided in tests:
|
||||
with self.subTest(dtype=dtype, axis=axis, strided=strided):
|
||||
np.random.seed(0)
|
||||
np_dtype = getattr(np, dtype)
|
||||
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
|
||||
a_mx = mx.array(a_np)
|
||||
if strided:
|
||||
a_mx = a_mx[::2, :, ::2]
|
||||
a_np = a_np[::2, :, ::2]
|
||||
|
||||
b_np = np.sort(a_np, axis=axis)
|
||||
b_mx = mx.sort(a_mx, axis=axis)
|
||||
b_np = np.sort(a_np, axis=axis)
|
||||
b_mx = mx.sort(a_mx, axis=axis)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
c_np = np.argsort(a_np, axis=axis)
|
||||
c_mx = mx.argsort(a_mx, axis=axis)
|
||||
d_np = np.take_along_axis(a_np, c_np, axis=axis)
|
||||
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
|
||||
c_np = np.argsort(a_np, axis=axis)
|
||||
c_mx = mx.argsort(a_mx, axis=axis)
|
||||
d_np = np.take_along_axis(a_np, c_np, axis=axis)
|
||||
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
|
||||
|
||||
self.assertTrue(np.array_equal(d_np, d_mx))
|
||||
self.assertEqual(c_mx.dtype, mx.uint32)
|
||||
self.assertTrue(np.array_equal(d_np, d_mx))
|
||||
self.assertEqual(c_mx.dtype, mx.uint32)
|
||||
|
||||
# Set random seed
|
||||
np.random.seed(0)
|
||||
|
||||
# Test multi-block sort
|
||||
a_np = np.random.normal(size=(32769,)).astype(np.float32)
|
||||
for strided in (False, True):
|
||||
with self.subTest(strided=strided):
|
||||
a_np = np.random.normal(size=(32769,)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[::3]
|
||||
a_np = a_np[::3]
|
||||
|
||||
b_np = np.sort(a_np)
|
||||
b_mx = mx.sort(a_mx)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
# Test multi-dum multi-block sort
|
||||
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[..., ::3]
|
||||
a_np = a_np[..., ::3]
|
||||
|
||||
b_np = np.sort(a_np, axis=-1)
|
||||
b_mx = mx.sort(a_mx, axis=-1)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[:, ::3]
|
||||
a_np = a_np[:, ::3]
|
||||
|
||||
b_np = np.sort(a_np, axis=1)
|
||||
b_mx = mx.sort(a_mx, axis=1)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
# test 0 strides
|
||||
a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
b_np = np.sort(a_np)
|
||||
b_mx = mx.sort(a_mx)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
# Test multi-dum multi-block sort
|
||||
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
b_np = np.sort(a_np, axis=-1)
|
||||
b_mx = mx.sort(a_mx, axis=-1)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
b_np = np.sort(a_np, axis=1)
|
||||
b_mx = mx.sort(a_mx, axis=1)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
b_np = np.broadcast_to(a_np, (16, 8))
|
||||
b_mx = mx.broadcast_to(a_mx, (16, 8))
|
||||
mx.eval(b_mx)
|
||||
for axis in (0, 1):
|
||||
c_np = np.sort(b_np, axis=axis)
|
||||
c_mx = mx.sort(b_mx, axis=axis)
|
||||
self.assertTrue(np.array_equal(c_np, c_mx))
|
||||
self.assertEqual(b_mx.dtype, c_mx.dtype)
|
||||
|
||||
def test_partition(self):
|
||||
shape = (3, 4, 5)
|
||||
|
Reference in New Issue
Block a user