mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup * bump mlx versions * minor update * use fast layer norm * version bump * update requirement for whisper * update requirement for gguf
This commit is contained in:
@@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs):
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x, weight, eps):
|
||||
x = x.astype(mx.float32)
|
||||
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return (1.0 + weight) * x.astype(weight.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@@ -37,7 +30,7 @@ class RMSNorm(nn.Module):
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
Reference in New Issue
Block a user