Fix strided sort bug (#1236)

* Use output strides in sort kernel

* fix zero strides bug
This commit is contained in:
Alex Barron
2024-06-26 14:32:11 -07:00
committed by GitHub
parent 5b0af4cdb1
commit 2615660e62
7 changed files with 222 additions and 262 deletions

View File

@@ -3,7 +3,7 @@
import math
import os
import unittest
from itertools import permutations
from itertools import permutations, product
import mlx.core as mx
import mlx_tests
@@ -1751,60 +1751,93 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2):
with self.subTest(dtype=dtype, axis=axis):
np.random.seed(0)
np_dtype = getattr(np, dtype)
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
a_mx = mx.array(a_np)
shape = (6, 4, 10)
tests = product(
("int32", "float32"), # type
(None, 0, 1, 2), # axis
(True, False), # strided
)
for dtype, axis, strided in tests:
with self.subTest(dtype=dtype, axis=axis, strided=strided):
np.random.seed(0)
np_dtype = getattr(np, dtype)
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::2, :, ::2]
a_np = a_np[::2, :, ::2]
b_np = np.sort(a_np, axis=axis)
b_mx = mx.sort(a_mx, axis=axis)
b_np = np.sort(a_np, axis=axis)
b_mx = mx.sort(a_mx, axis=axis)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
c_np = np.argsort(a_np, axis=axis)
c_mx = mx.argsort(a_mx, axis=axis)
d_np = np.take_along_axis(a_np, c_np, axis=axis)
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
c_np = np.argsort(a_np, axis=axis)
c_mx = mx.argsort(a_mx, axis=axis)
d_np = np.take_along_axis(a_np, c_np, axis=axis)
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
self.assertTrue(np.array_equal(d_np, d_mx))
self.assertEqual(c_mx.dtype, mx.uint32)
self.assertTrue(np.array_equal(d_np, d_mx))
self.assertEqual(c_mx.dtype, mx.uint32)
# Set random seed
np.random.seed(0)
# Test multi-block sort
a_np = np.random.normal(size=(32769,)).astype(np.float32)
for strided in (False, True):
with self.subTest(strided=strided):
a_np = np.random.normal(size=(32769,)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::3]
a_np = a_np[::3]
b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# Test multi-dum multi-block sort
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[..., ::3]
a_np = a_np[..., ::3]
b_np = np.sort(a_np, axis=-1)
b_mx = mx.sort(a_mx, axis=-1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[:, ::3]
a_np = a_np[:, ::3]
b_np = np.sort(a_np, axis=1)
b_mx = mx.sort(a_mx, axis=1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# test 0 strides
a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])
a_mx = mx.array(a_np)
b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# Test multi-dum multi-block sort
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
a_mx = mx.array(a_np)
b_np = np.sort(a_np, axis=-1)
b_mx = mx.sort(a_mx, axis=-1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
a_mx = mx.array(a_np)
b_np = np.sort(a_np, axis=1)
b_mx = mx.sort(a_mx, axis=1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
b_np = np.broadcast_to(a_np, (16, 8))
b_mx = mx.broadcast_to(a_mx, (16, 8))
mx.eval(b_mx)
for axis in (0, 1):
c_np = np.sort(b_np, axis=axis)
c_mx = mx.sort(b_mx, axis=axis)
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, c_mx.dtype)
def test_partition(self):
shape = (3, 4, 5)