Fast RMS Norm (#862)
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
@@ -117,6 +117,8 @@ class RMSNorm(Module):
     where :math:`\gamma` is a learned per feature dimension parameter initialized at
     1.
 
+    Note the accumulation for the mean is done in 32-bit precision.
+
     [1]: https://arxiv.org/abs/1910.07467
 
     Args:
@@ -133,18 +135,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"
 
     def __call__(self, x):
-        # S is 1/sqrt(N) where N is the size of the features of x and is used
-        # to compute a numerically more stable RMS of x by multiplying with S
-        # first and summing.
-        #
-        # This way we prefer underflow over overflow which is controlled with
-        # the parameter epsilon anyway.
-        S = 1 / x.shape[-1] ** 0.5
-
-        n = (x * S).square().sum(axis=-1, keepdims=True)
-        n = mx.rsqrt(n + self.eps)
-
-        return self.weight * x * n
+        return mx.fast.rms_norm(x, self.weight, self.eps)
 
 
 class GroupNorm(Module):
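The comments removed above document the old trick: pre-scale by S = 1/sqrt(N) so the summation computes the RMS without overflowing, accepting possible underflow instead. The fused kernel squares in float32 (per the commit message and the new docstring note), which sidesteps the half-precision underflow entirely. A minimal sketch of the failure mode this avoids, with illustrative values:

import mlx.core as mx

# Squaring a small float16 value underflows to zero, biasing the RMS low;
# upcasting to float32 before squaring preserves it.
x = mx.array([1e-4], dtype=mx.float16)
print(x.square())                     # 0.0: 1e-8 is below float16's smallest subnormal
print(x.astype(mx.float32).square())  # ~1e-8 in float32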
@@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
   auto m =
       parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
 
+  m.def(
+      "rms_norm",
+      [](const array& x,
+         const array& weight,
+         float eps,
+         const StreamOrDevice& s /* = {} */) {
+        return fast::rms_norm(x, weight, eps, s);
+      },
+      "x"_a,
+      "weight"_a,
+      "eps"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Root Mean Square normalization (RMS norm).
+
+        The normalization is with respect to the last axis of the input ``x``.
+
+        Args:
+            x (array): Input array.
+            weight (array): A multiplicative weight to scale the result by.
+              The ``weight`` should be one-dimensional with the same size
+              as the last axis of ``x``.
+            eps (float): A small additive constant for numerical stability.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
+
   m.def(
       "rope",
       [](const array& a,
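The binding makes ``stream`` keyword-only and gives ``eps`` no default, so it must be passed explicitly. A short usage sketch of the resulting Python API; the shapes are illustrative:

import mlx.core as mx

x = mx.random.normal(shape=(2, 512)).astype(mx.float16)
weight = mx.ones((512,), dtype=mx.float16)  # one weight per feature in the last axis

# Normalizes over the last axis of x and scales the result by weight.
y = mx.fast.rms_norm(x, weight, 1e-5)
print(y.shape)  # (2, 512)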
@@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_rms_norm(self):
+        def rms_norm(x, weight, eps):
+            x = x.astype(mx.float32)
+            x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+            return weight * x.astype(weight.dtype)
+
+        # Per dtype absolute tolerance
+        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
+
+        dtypes = [mx.float32, mx.float16, mx.bfloat16]
+        epss = [1e-3, 1e-5]
+        dimss = [31, 32, 33]
+        defaults = (mx.float32, 1e-5, 32)
+
+        for dtype in dtypes:
+            _, eps, dims = defaults
+            x = mx.random.uniform(
+                shape=(
+                    2,
+                    dims,
+                )
+            ).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for eps in epss:
+            dtype, _, dims = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for dims in dimss:
+            dtype, eps, _ = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        # Test > 4096
+        dims, dtype, eps = 4099, mx.float32, 1e-5
+        x = mx.random.uniform(shape=(dims,)).astype(dtype)
+        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+        rx = rms_norm(x, weight, eps)
+        rx_fast = mx.fast.rms_norm(x, weight, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
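Since the point of the fused op is speed, here is a rough way one might compare it against the pure-MLX reference used in the test above; the harness and sizes are illustrative, and timings depend on the device:

import timeit
import mlx.core as mx

def reference_rms_norm(x, weight, eps):
    # Same float32-accumulating reference as in the test above.
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)

x = mx.random.uniform(shape=(1024, 4099)).astype(mx.float16)
weight = mx.random.uniform(shape=(4099,)).astype(mx.float16)

# mx.eval forces the lazy graph to run so each timing covers real work.
t_ref = timeit.timeit(lambda: mx.eval(reference_rms_norm(x, weight, 1e-5)), number=100)
t_fast = timeit.timeit(lambda: mx.eval(mx.fast.rms_norm(x, weight, 1e-5)), number=100)
print(f"reference: {t_ref:.3f}s, fused: {t_fast:.3f}s")

The 4099-wide last axis exercises the looped kernel path that the tests also target with their "> 4096" case.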