Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
This commit is contained in:
Awni Hannun
2024-09-04 19:10:43 -07:00
committed by GitHub
parent 11371fe251
commit 7cca1727af
12 changed files with 129 additions and 39 deletions

View File

@@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_slice_into_layer_norm(self):
dim = 128
eps = 1e-5
x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]
rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)
rx = layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)
def test_layer_norm_grad(self):
D = 32
eps = 1e-5