mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00

* initial commit for command-R * update mlp, layernorm, lm_head and model args * add custom layernorm * add default to tie_word_embeddings * add layernorm weight type and refactor * update layernorm (bias conditional) in model/layers * fix layer norm use traditional rope * add test --------- Co-authored-by: Awni Hannun <awni@apple.com>
81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
from functools import partial
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
|
|
@partial(mx.compile, shapeless=True)
|
|
def rms_norm(x, weight, eps):
|
|
x = x.astype(mx.float32)
|
|
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
|
return weight * x.astype(weight.dtype)
|
|
|
|
|
|
class RMSNorm(nn.Module):
|
|
def __init__(self, dims: int, eps: float = 1e-5):
|
|
super().__init__()
|
|
self.weight = mx.ones((dims,))
|
|
self.eps = eps
|
|
|
|
def __call__(self, x):
|
|
return rms_norm(x, self.weight, self.eps)
|
|
|
|
|
|
@partial(mx.compile, shapeless=True)
|
|
def ln_norm(x, eps, weight=None, bias=None):
|
|
"""
|
|
Layer normalization for input tensor x.
|
|
|
|
Args:
|
|
x (np.ndarray): Input tensor.
|
|
eps (float, optional): Small value to avoid division by zero.
|
|
weight (np.ndarray, optional): Weight tensor for normalization.
|
|
bias (np.ndarray, optional): Bias tensor for normalization.
|
|
|
|
Returns:
|
|
np.ndarray: Normalized tensor.
|
|
"""
|
|
t = x.dtype
|
|
x = x.astype(mx.float32)
|
|
|
|
# Compute mean and variance along the last dimension
|
|
means = mx.mean(x, axis=-1, keepdims=True)
|
|
var = mx.var(x, axis=-1, keepdims=True)
|
|
|
|
# Normalize the input tensor
|
|
x = (x - means) * mx.rsqrt(var + eps)
|
|
x = x.astype(t)
|
|
|
|
# Apply weight and bias if provided
|
|
if weight is not None:
|
|
x = x * weight
|
|
if bias is not None:
|
|
x = x + bias
|
|
return x
|
|
|
|
|
|
class LayerNorm(nn.Module):
|
|
def __init__(
|
|
self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
|
|
):
|
|
super().__init__()
|
|
self.eps = eps
|
|
self.dims = dims
|
|
self.affine = affine
|
|
|
|
if affine:
|
|
self.weight = mx.ones((dims,))
|
|
self.bias = mx.zeros((dims,)) if bias else None
|
|
|
|
def _extra_repr(self):
|
|
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
if self.affine:
|
|
if self.bias is not None:
|
|
return ln_norm(x, self.eps, self.weight, self.bias)
|
|
else:
|
|
return ln_norm(x, self.eps, self.weight)
|
|
else:
|
|
return ln_norm(x, self.eps)
|