Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-26 02:33:21 +08:00)
Fix segfault in layer norm grad (#955)
This commit is contained in:
parent e142aaf8a1, commit d88d2124b5
@@ -392,7 +392,7 @@ void LayerNormVJP::eval_gpu(
     group_dims = MTL::Size(threadgroup_size, 1, 1);
   }
 
-  uint32_t w_stride = w.strides()[0];
+  uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
   set_array_buffer(compute_encoder, w, 1);
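The one-line change above guards the weight stride: when layer_norm is called without a weight, the `w` that reaches `LayerNormVJP::eval_gpu` is not a 1-D array (presumably a dimensionless placeholder), so `w.strides()[0]` indexes an empty strides vector, which is undefined behavior and the likely source of the segfault. With the guard, the kernel receives a stride of 0 instead. A minimal repro sketch of the crashing case, mirroring the new test below (before this commit, the grad call could segfault on the Metal backend):

import mlx.core as mx

# layer_norm with no weight and no bias, differentiated through
# mx.grad; the VJP is where the bad stride read used to happen.
eps = 1e-5
x = mx.random.normal(shape=(2, 2, 8))
mx.eval(x)

f = lambda x: mx.fast.layer_norm(x, None, None, eps).sum()
gx = mx.grad(f)(x)  # previously could crash in LayerNormVJP::eval_gpu
mx.eval(gx)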
@@ -375,6 +375,17 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gb1).max(), 1e-9)
         self.assertLess(mx.abs(gb2).max(), 1e-9)
 
+    def test_layer_norm_grad_no_params(self):
+        eps = 1e-5
+        f1 = lambda x: layer_norm(x, None, None, eps).sum()
+        f2 = lambda x: mx.fast.layer_norm(x, None, None, eps).sum()
+        x = mx.random.normal(shape=(2, 2, 8))
+        mx.eval(x)
+
+        gx1 = mx.grad(f1)(x)
+        gx2 = mx.grad(f2)(x)
+        self.assertTrue(mx.allclose(gx1, gx2, atol=1e-6))
+
     def test_layer_norm_grad_params(self):
         eps = 1e-5
         f1 = lambda params, x: (layer_norm(x, params[0], params[1], eps)).sum()
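For context, the `layer_norm` used by `f1` is the pure-MLX reference helper defined elsewhere in the test file, which the fused `mx.fast.layer_norm` kernel is checked against. A sketch of the usual formulation (the exact helper may differ; the name, the float32 upcast, and the structure here are assumptions, not the file's verbatim code):

import mlx.core as mx

def layer_norm(x, weight, bias, eps):
    # Normalize over the last axis in float32, then apply the
    # optional affine parameters and cast back to the input dtype.
    t = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, axis=-1, keepdims=True)
    v = mx.var(x, axis=-1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y.astype(t)

Passing None for both weight and bias exercises exactly the placeholder path that the C++ guard above fixes.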