From 110d9b149df91b819dcbeae86f1af1165171a429 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 1 Apr 2024 06:15:50 -0700
Subject: [PATCH] Layer norm grad fix donation bug (#941)

* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun
---
 mlx/backend/metal/normalization.cpp |  9 ++++++++-
 python/tests/test_fast.py           | 15 +++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index df5301244..b84c04f74 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -355,7 +355,14 @@ void LayerNormVJP::eval_gpu(
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
     strided_reduce_general_dispatch(
-        g, gb, "sum", plan, {0}, compute_encoder, d, s);
+        g_in_gx ? gx : (g_in_gw ? gw_temp : g),
+        gb,
+        "sum",
+        plan,
+        {0},
+        compute_encoder,
+        d,
+        s);
   }
 
   const int simd_size = 32;
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index a0ea8e224..b144f3cd8 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -375,6 +375,21 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gb1).max(), 1e-9)
         self.assertLess(mx.abs(gb2).max(), 1e-9)
 
+    def test_layer_norm_grad_params(self):
+        eps = 1e-5
+        f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()
+        f2 = lambda params, x: (mx.fast.layer_norm(x, params[0], params[1], eps)).sum()
+
+        w = mx.ones((8,))
+        b = mx.zeros((8,))
+        x = mx.random.normal(shape=(2, 2, 8))
+        mx.eval(x, w, b)
+
+        gw1, gb1 = mx.grad(f1)((w, b), x)
+        gw2, gb2 = mx.grad(f2)((w, b), x)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
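
Note on the fix: as the C++ hunk shows, the cotangent array g may have had its
buffer donated to the output gx (g_in_gx) or to the temporary gw_temp
(g_in_gw). The sum reduction that produces the bias gradient gb must then read
from whichever array currently owns that data, not from g, whose buffer may
already have been reused.

Below is a minimal standalone sketch of the gradient comparison the new test
performs. It assumes an MLX build containing this patch; the reference
layer_norm helper is reimplemented here in plain mx ops, since the test file
imports it from elsewhere in the test module.

    import mlx.core as mx

    def layer_norm(x, weight, bias, eps):
        # Reference implementation (assumed helper), written with plain mx ops.
        mu = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        return (x - mu) * mx.rsqrt(var + eps) * weight + bias

    eps = 1e-5
    w = mx.ones((8,))
    b = mx.zeros((8,))
    x = mx.random.normal(shape=(2, 2, 8))
    mx.eval(x, w, b)

    f_ref = lambda params, x: layer_norm(x, params[0], params[1], eps).sum()
    f_fast = lambda params, x: mx.fast.layer_norm(x, params[0], params[1], eps).sum()

    # Gradients w.r.t. (weight, bias). Before the fix, the fused kernel's gb
    # reduction could read g's buffer after it had been donated to gx or
    # gw_temp, corrupting the bias gradient.
    gw1, gb1 = mx.grad(f_ref)((w, b), x)
    gw2, gb2 = mx.grad(f_fast)((w, b), x)
    print(mx.abs(gw1 - gw2).max(), mx.abs(gb1 - gb2).max())  # both should be ~0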