Allow no copy negative strides in as_strided and slice (#1688)

* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
This commit is contained in:
Awni Hannun
2024-12-12 08:59:45 -08:00
committed by GitHub
parent 4d595a2a39
commit 6bd28d246e
15 changed files with 133 additions and 163 deletions

View File

@@ -1758,6 +1758,10 @@ class TestOps(mlx_tests.MLXTestCase):
y_mlx = mx.as_strided(x_mlx, shape, stride, offset)
self.assertTrue(np.array_equal(y_npy, y_mlx))
x = mx.random.uniform(shape=(32,))
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1]))
def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)