Fast RMS Norm (#862)

* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Do the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun
2024-03-21 07:20:54 -07:00
committed by GitHub
parent 4650d94d98
commit a54f06b16f
17 changed files with 493 additions and 41 deletions

@@ -117,6 +117,8 @@ class RMSNorm(Module):
     where :math:`\gamma` is a learned per feature dimension parameter initialized at
     1.
 
+    Note the accumulation for the mean is done in 32-bit precision.
+
     [1]: https://arxiv.org/abs/1910.07467
 
     Args:
@@ -133,18 +135,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"
 
     def __call__(self, x):
-        # S is 1/sqrt(N) where N is the size of the features of x and is used
-        # to compute a numerically more stable RMS of x by multiplying with S
-        # first and summing.
-        #
-        # This way we prefer underflow over overflow which is controlled with
-        # the parameter epsilon anyway.
-        S = 1 / x.shape[-1] ** 0.5
-
-        n = (x * S).square().sum(axis=-1, keepdims=True)
-        n = mx.rsqrt(n + self.eps)
-
-        return self.weight * x * n
+        return mx.fast.rms_norm(x, self.weight, self.eps)
 
 
 class GroupNorm(Module):
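
The commit message and the new docstring note both mention doing the squaring and accumulation in 32-bit precision. The sketch below is one way to see why that can matter for float16 inputs. It is not the Metal kernel from this commit and does not call MLX. It uses only the Python standard library, simulating float16 and float32 rounding with the `struct` formats `"e"` and `"f"`. The helper names (`to_half`, `to_single`, `scale_prescaled_half`, `scale_float32`) are hypothetical, the learned weight is left out, and the final reciprocal square root is taken in double precision for simplicity.

```python
import math
import struct


def to_half(v):
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def to_single(v):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", v))[0]


def scale_prescaled_half(x, eps):
    """Removed formulation, with every intermediate rounded to float16.

    Mirrors n = rsqrt(((x * S) ** 2).sum() + eps) with S = 1 / sqrt(N).
    """
    s = to_half(1 / len(x) ** 0.5)
    total = 0.0
    for v in x:
        sq = to_half(to_half(v * s) ** 2)  # the square is rounded to float16
        total = to_half(total + sq)
    return 1 / math.sqrt(total + eps)


def scale_float32(x, eps):
    """Square and accumulate in float32, then divide by N."""
    total = 0.0
    for v in x:
        total = to_single(total + to_single(v * v))
    return 1 / math.sqrt(total / len(x) + eps)


# Assumed test input: 4096 features, all equal to 0.01 rounded to float16.
x = [to_half(0.01)] * 4096
eps = 1e-5

half_scale = scale_prescaled_half(x, eps)
f32_scale = scale_float32(x, eps)
# Double-precision reference for comparison.
exact_scale = 1 / math.sqrt(sum(v * v for v in x) / len(x) + eps)

print(half_scale, f32_scale, exact_scale)
```

With this input, each pre-scaled square is below half of the smallest float16 subnormal and rounds to zero, so the float16 path falls back to `1 / sqrt(eps)`. The float32 path stays close to the double-precision reference. For larger activations or smaller feature sizes the two paths may agree closely, so treat this as an illustration of the failure mode rather than a statement about typical error.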