mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Fix slice data size (#1394)
* fix slice data size and add tests * fix contiguous flag * simplify stride and perform copy for non-contiguous arrays * fix cpu * comment
This commit is contained in:
@@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
def test_slice_into_layer_norm(self):
|
||||
dim = 128
|
||||
eps = 1e-5
|
||||
x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]
|
||||
rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)
|
||||
|
||||
def test_layer_norm_grad(self):
|
||||
D = 32
|
||||
eps = 1e-5
|
||||
|
Reference in New Issue
Block a user