mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
Layer norm grad fix donation bug (#941)
* add layer norm grad test * Fix donation bug in layernorm vjp --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
9cbff5ec1d
commit
110d9b149d
@ -355,7 +355,14 @@ void LayerNormVJP::eval_gpu(
|
|||||||
ReductionPlan plan(
|
ReductionPlan plan(
|
||||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||||
strided_reduce_general_dispatch(
|
strided_reduce_general_dispatch(
|
||||||
g, gb, "sum", plan, {0}, compute_encoder, d, s);
|
g_in_gx ? gx : (g_in_gw ? gw_temp : g),
|
||||||
|
gb,
|
||||||
|
"sum",
|
||||||
|
plan,
|
||||||
|
{0},
|
||||||
|
compute_encoder,
|
||||||
|
d,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int simd_size = 32;
|
const int simd_size = 32;
|
||||||
|
@ -375,6 +375,21 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
self.assertLess(mx.abs(gb1).max(), 1e-9)
|
self.assertLess(mx.abs(gb1).max(), 1e-9)
|
||||||
self.assertLess(mx.abs(gb2).max(), 1e-9)
|
self.assertLess(mx.abs(gb2).max(), 1e-9)
|
||||||
|
|
||||||
|
def test_layer_norm_grad_params(self):
|
||||||
|
eps = 1e-5
|
||||||
|
f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()
|
||||||
|
f2 = lambda params, x: (mx.fast.layer_norm(x, params[0], params[1], eps)).sum()
|
||||||
|
|
||||||
|
w = mx.ones((8,))
|
||||||
|
b = mx.zeros((8,))
|
||||||
|
x = mx.random.normal(shape=(2, 2, 8))
|
||||||
|
mx.eval(x, w, b)
|
||||||
|
|
||||||
|
gw1, gb1 = mx.grad(f1)((w, b), x)
|
||||||
|
gw2, gb2 = mx.grad(f2)((w, b), x)
|
||||||
|
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
||||||
|
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
|
||||||
|
|
||||||
def test_fast_transforms(self):
|
def test_fast_transforms(self):
|
||||||
x = mx.random.uniform(shape=(2, 2, 8))
|
x = mx.random.uniform(shape=(2, 2, 8))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user