Fast RMS Norm (#862)

* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-03-21 07:20:54 -07:00
committed by GitHub
parent 4650d94d98
commit a54f06b16f
17 changed files with 493 additions and 41 deletions

View File

@@ -117,6 +117,8 @@ class RMSNorm(Module):
where :math:`\gamma` is a learned per feature dimension parameter initialized at
1.
Note the accumulation for the mean is done in 32-bit precision.
[1]: https://arxiv.org/abs/1910.07467
Args:
@@ -133,18 +135,7 @@ class RMSNorm(Module):
return f"{self.weight.shape[0]}, eps={self.eps}"
def __call__(self, x):
    """Apply RMS normalization over the last axis of ``x``.

    Delegates to the fused ``mx.fast.rms_norm`` kernel, which scales
    ``x`` by the reciprocal root-mean-square of its last axis (with
    ``self.eps`` added before the ``rsqrt`` for numerical stability)
    and multiplies the result by the learned per-feature
    ``self.weight``.

    Args:
        x (array): Input array; its last axis must match the size of
            ``self.weight``.

    Returns:
        array: The normalized and scaled array, same shape as ``x``.
    """
    # The hand-written implementation (pre-scale by 1/sqrt(N), square,
    # sum, rsqrt) that used to live here was dead code sitting above
    # this return; the fused kernel performs the same computation with
    # float32 accumulation in a single op.
    return mx.fast.rms_norm(x, self.weight, self.eps)
class GroupNorm(Module):

View File

@@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
// Bind mlx.core.fast.rms_norm into the Python `fast` submodule.
// The lambda simply forwards to the C++ fast::rms_norm primitive.
m.def(
"rms_norm",
[](const array& x,
const array& weight,
float eps,
const StreamOrDevice& s /* = {} */) {
return fast::rms_norm(x, weight, eps, s);
},
// Positional arguments: x, weight, eps.
"x"_a,
"weight"_a,
"eps"_a,
// Everything after this marker (only `stream`) is keyword-only on the
// Python side; its default is None, mapping to the default stream.
nb::kw_only(),
"stream"_a = nb::none(),
// Explicit signature string so the Python-facing type hints render
// Union[None, Stream, Device] instead of the raw C++ types.
nb::sig(
"def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Root Mean Square normalization (RMS norm).
The normalization is with respect to the last axis of the input ``x``.
Args:
x (array): Input array.
weight (array): A multiplicative weight to scale the result by.
The ``weight`` should be one-dimensional with the same size
as the last axis of ``x``.
eps (float): A small additive constant for numerical stability.
Returns:
array: The output array.
)pbdoc");
m.def(
"rope",
[](const array& a,

View File

@@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_rms_norm(self):
    """Compare mx.fast.rms_norm against a float32 reference implementation.

    Sweeps dtype, eps, and feature size one at a time around a fixed
    default configuration, then checks a size above 4096.
    """

    def rms_norm(x, weight, eps):
        # Reference implementation: accumulate in float32, then cast
        # back to the weight's dtype before the elementwise scale.
        x = x.astype(mx.float32)
        x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
        return weight * x.astype(weight.dtype)

    # Per-dtype absolute tolerance for fast-vs-reference comparison.
    tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

    def check(dtype, eps, dims, tol):
        # One comparison of the fused kernel vs. the reference on a
        # random (2, dims) input.  Factored out because the original
        # repeated this body verbatim in all three sweeps below.
        x = mx.random.uniform(shape=(2, dims)).astype(dtype)
        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
        rx = rms_norm(x, weight, eps)
        rx_fast = mx.fast.rms_norm(x, weight, eps)
        self.assertLess(mx.abs(rx - rx_fast).max(), tol)

    # Defaults held fixed while one parameter at a time is varied.
    default_dtype, default_eps, default_dims = mx.float32, 1e-5, 32
    for dtype in [mx.float32, mx.float16, mx.bfloat16]:
        check(dtype, default_eps, default_dims, tolerances[dtype])
    for eps in [1e-3, 1e-5]:
        check(default_dtype, eps, default_dims, tolerances[default_dtype])
    for dims in [31, 32, 33]:
        check(default_dtype, eps, dims, tolerances[default_dtype])

    # Test > 4096 features — presumably exercises the looped kernel
    # variant mentioned in the commit message; verify against caller.
    dims, dtype, eps = 4099, mx.float32, 1e-5
    x = mx.random.uniform(shape=(dims,)).astype(dtype)
    weight = mx.random.uniform(shape=(dims,)).astype(dtype)
    rx = rms_norm(x, weight, eps)
    rx_fast = mx.fast.rms_norm(x, weight, eps)
    self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
def test_fast_transforms(self):
x = mx.random.uniform(shape=(2, 2, 8))