From d88d2124b5fa0c8661c46c785acdd6f06f467420 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 4 Apr 2024 10:59:15 -0700 Subject: [PATCH] segfault layer norm grad (#955) --- mlx/backend/metal/normalization.cpp | 2 +- python/tests/test_fast.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index b84c04f74..cd7ad7eac 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -392,7 +392,7 @@ void LayerNormVJP::eval_gpu( group_dims = MTL::Size(threadgroup_size, 1, 1); } - uint32_t w_stride = w.strides()[0]; + uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0); set_array_buffer(compute_encoder, w, 1); diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b144f3cd8..3b2db95c6 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -375,6 +375,17 @@ class TestFast(mlx_tests.MLXTestCase): self.assertLess(mx.abs(gb1).max(), 1e-9) self.assertLess(mx.abs(gb2).max(), 1e-9) + def test_layer_norm_grad_no_params(self): + eps = 1e-5 + f1 = lambda x: layer_norm(x, None, None, eps).sum() + f2 = lambda x: mx.fast.layer_norm(x, None, None, eps).sum() + x = mx.random.normal(shape=(2, 2, 8)) + mx.eval(x) + + gx1 = mx.grad(f1)(x) + gx2 = mx.grad(f2)(x) + self.assertTrue(mx.allclose(gx1, gx2, atol=1e-6)) + def test_layer_norm_grad_params(self): eps = 1e-5 f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()