mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Check edge case handling in row reduce med kernel (#858)
This commit is contained in:
@@ -590,6 +590,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))
|
||||
self.assertTrue(np.all(sum_npy == sum_mlx))
|
||||
|
||||
x_npy = (
|
||||
np.arange(3 * 2 * 3 * 3 * 3 * 3)
|
||||
.reshape(3, 2, 3, 3, 3, 3)
|
||||
.astype(np.float32)
|
||||
)
|
||||
x_mlx = mx.array(x_npy)
|
||||
|
||||
y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5))
|
||||
y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5))
|
||||
|
||||
self.assertTrue(np.array_equal(y_mlx, y_npy))
|
||||
|
||||
def test_prod(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
Reference in New Issue
Block a user