mlx-examples/llms/mlx_lm/models/layers.py
from functools import partial

import mlx.core as mx
import mlx.nn as nn


@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    # Compute in float32 for numerical stability, then cast back to the
    # weight's dtype before applying the learned scale.
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)
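
# Illustrative usage (a sketch added for clarity, not part of the upstream
# module): RMSNorm rescales each vector along the last axis to roughly unit
# root-mean-square, then applies a learned per-dimension weight. The shape
# and dims values below are arbitrary example choices.
#
#   x = mx.random.normal(shape=(2, 16, 512))
#   y = RMSNorm(512)(x)  # same shape and dtype handling as the input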

@partial(mx.compile, shapeless=True)
def ln_norm(x, eps, weight=None, bias=None):
    """
    Layer normalization for the input tensor x.

    Args:
        x (mx.array): Input tensor.
        eps (float): Small value to avoid division by zero.
        weight (mx.array, optional): Weight tensor for scaling.
        bias (mx.array, optional): Bias tensor for shifting.

    Returns:
        mx.array: Normalized tensor.
    """
    t = x.dtype
    x = x.astype(mx.float32)

    # Compute mean and variance along the last dimension
    means = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)

    # Normalize the input tensor and restore the original dtype
    x = (x - means) * mx.rsqrt(var + eps)
    x = x.astype(t)

    # Apply weight and bias if provided
    if weight is not None:
        x = x * weight
    if bias is not None:
        x = x + bias
    return x

class LayerNorm(nn.Module):
    def __init__(
        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.dims = dims
        self.affine = affine
        if affine:
            self.weight = mx.ones((dims,))
            self.bias = mx.zeros((dims,)) if bias else None

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if self.affine:
            if self.bias is not None:
                return ln_norm(x, self.eps, self.weight, self.bias)
            else:
                return ln_norm(x, self.eps, self.weight)
        else:
            return ln_norm(x, self.eps)
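
# A minimal sanity check (an illustrative addition, not part of the upstream
# file), run only when this module is executed directly. It verifies that
# RMSNorm preserves the input shape and that LayerNorm output has roughly
# zero mean and unit variance along the last axis (with its freshly
# initialized weight=1 and bias=0). The shapes below are assumptions.
if __name__ == "__main__":
    x = mx.random.normal(shape=(2, 8, 64))

    rms = RMSNorm(64)
    y = rms(x)
    assert y.shape == x.shape

    ln = LayerNorm(64)
    z = ln(x)
    # With the default affine parameters, each row should be standardized.
    print("max |mean|:", mx.abs(z.mean(-1)).max().item())
    print("max |var - 1|:", mx.abs(mx.var(z, axis=-1) - 1.0).max().item())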