mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Reduce specializations (#1607)
* start of reduce specializations * fix all reduce * fix many dims * fix * non-jit tests clear * cleanup instantiations * cpu merges * change dim specializations * optimize * fix jit * fix jit * use higher precision for integer sum+prod * fixes
This commit is contained in:
@@ -131,6 +131,28 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
mxsum = y.sum().item()
|
||||
self.assertEqual(npsum, mxsum)
|
||||
|
||||
def test_many_reduction_axes(self):
|
||||
|
||||
def check(x, axes):
|
||||
expected = x
|
||||
for ax in axes:
|
||||
expected = mx.sum(expected, axis=ax, keepdims=True)
|
||||
out = mx.sum(x, axis=axes, keepdims=True)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4))
|
||||
check(x, (0, 2, 4))
|
||||
|
||||
x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4))
|
||||
check(x, (0, 2, 4, 6))
|
||||
|
||||
x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4))
|
||||
check(x, (0, 2, 4, 6, 8))
|
||||
|
||||
x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4, 128))
|
||||
x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
|
||||
check(x, (1, 3, 5, 7, 9))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user