Adds mx.fast.layer_norm (#870)
commit 2225374060 (parent 105d236889), committed by GitHub
@@ -85,13 +85,19 @@ class LayerNorm(Module):
         eps (float): A small additive constant for numerical stability
         affine (bool): If True learn an affine transform to apply after the
             normalization
+        bias (bool): If True include a translation to the affine
+            transformation. If set to False the transformation is not really affine
+            just scaling.
     """
 
-    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(
+        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
+    ):
         super().__init__()
         if affine:
-            self.bias = mx.zeros((dims,))
             self.weight = mx.ones((dims,))
+            if bias:
+                self.bias = mx.zeros((dims,))
         self.eps = eps
         self.dims = dims
 
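The new bias keyword only has an effect when affine=True. A minimal usage sketch of the updated constructor (not part of this commit; shapes and values are illustrative, import paths assumed to be mlx.core and mlx.nn):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.uniform(shape=(4, 32))

    ln = nn.LayerNorm(32)                                       # learned scale and shift (defaults)
    ln_scale_only = nn.LayerNorm(32, affine=True, bias=False)   # learned scale, no translation
    ln_plain = nn.LayerNorm(32, affine=False)                   # no learned parameters at all

    y = ln(x)  # same shape as x, normalized over the last axis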
@@ -99,10 +105,9 @@ class LayerNorm(Module):
         return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
 
     def __call__(self, x):
-        means = mx.mean(x, axis=-1, keepdims=True)
-        var = mx.var(x, axis=-1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x + self.bias) if "weight" in self else x
+        weight = self.weight if "weight" in self else None
+        bias = self.bias if "bias" in self else None
+        return mx.fast.layer_norm(x, weight, bias, self.eps)
 
 
 class RMSNorm(Module):
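The forward pass above now delegates to the fused op instead of composing mean, var, and rsqrt by hand. A small equivalence sketch (not part of the commit; values and expected error are illustrative):

    import mlx.core as mx

    x = mx.random.uniform(shape=(2, 32))
    eps = 1e-5

    # The computation the old __call__ performed explicitly:
    means = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    reference = (x - means) * mx.rsqrt(var + eps)

    # The fused path; passing None skips scaling and translation.
    fused = mx.fast.layer_norm(x, None, None, eps)

    print(mx.abs(reference - fused).max())  # small, on the order of float32 rounding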
@@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
             array: The output array.
       )pbdoc");
 
+  m.def(
+      "layer_norm",
+      [](const array& x,
+         const std::optional<array>& weight,
+         const std::optional<array>& bias,
+         float eps,
+         const StreamOrDevice& s /* = {} */) {
+        return fast::layer_norm(x, weight, bias, eps, s);
+      },
+      "x"_a,
+      "weight"_a.none(),
+      "bias"_a.none(),
+      "eps"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Layer normalization.
+
+        The normalization is with respect to the last axis of the input ``x``.
+
+        Args:
+            x (array): Input array.
+            weight (array, optional): A multiplicative weight to scale the result by.
+              The ``weight`` should be one-dimensional with the same size
+              as the last axis of ``x``. If set to ``None`` then no scaling happens.
+            bias (array, optional): An additive offset to be added to the result.
+              The ``bias`` should be one-dimensional with the same size
+              as the last axis of ``x``. If set to ``None`` then no translation happens.
+            eps (float): A small additive constant for numerical stability.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
+
   m.def(
       "rope",
       [](const array& a,
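From Python the binding behaves as documented above: weight and bias are positional and may be None, while stream is keyword-only. A short sketch under those assumptions (shapes chosen arbitrarily, not part of the commit):

    import mlx.core as mx

    x = mx.random.uniform(shape=(2, 32))
    w = mx.random.uniform(shape=(32,))
    b = mx.random.uniform(shape=(32,))

    y_full = mx.fast.layer_norm(x, w, b, 1e-5)         # scale and shift
    y_scale = mx.fast.layer_norm(x, w, None, 1e-5)     # scale only
    y_plain = mx.fast.layer_norm(x, None, None, 1e-5)  # plain normalization

    # The stream/device argument must be passed by keyword.
    y_cpu = mx.fast.layer_norm(x, w, b, 1e-5, stream=mx.cpu)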
@@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase):
         rx_fast = mx.fast.rms_norm(x, weight, eps)
         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
 
+    def test_layer_norm(self):
+        def layer_norm(x, weight, bias, eps):
+            ot = x.dtype
+            x = x.astype(mx.float32)
+            mean = x.mean(axis=-1, keepdims=True)
+            var = x.var(axis=-1, keepdims=True)
+            x = (x - mean) * mx.rsqrt(var + eps)
+            x = x.astype(ot)
+            if weight is not None:
+                x = x * weight
+            if bias is not None:
+                x = x + bias
+            return x
+
+        # Per dtype absolute tolerance
+        tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2}
+
+        dtypes = [mx.float32, mx.float16, mx.bfloat16]
+        epss = [1e-3, 1e-5]
+        dimss = [31, 32, 33]
+        defaults = (mx.float32, 1e-5, 32)
+
+        for dtype in dtypes:
+            _, eps, dims = defaults
+            x = mx.random.uniform(
+                shape=(
+                    2,
+                    dims,
+                )
+            ).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for eps in epss:
+            dtype, _, dims = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for dims in dimss:
+            dtype, eps, _ = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        # Test > 4096
+        dims, dtype, eps = 4099, mx.float32, 1e-5
+        x = mx.random.uniform(shape=(dims,)).astype(dtype)
+        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+        bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+        rx = layer_norm(x, weight, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, weight, None, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, None, eps)
+        rx_fast = mx.fast.layer_norm(x, None, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
 