mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
Add support for Cohere's Command-R (#565)
* initial commit for command-R * update mlp, layernorm, lm_head and model args * add custom layernorm * add default to tie_word_embeddings * add layernorm weight type and refactor * update layernorm (bias conditional) in model/layers * fix layer norm use traditional rope * add test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
3535408c99
commit
76c3244cc5
171
llms/mlx_lm/models/cohere.py
Normal file
171
llms/mlx_lm/models/cohere.py
Normal file
@ -0,0 +1,171 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 8192
|
||||
num_hidden_layers: int = 40
|
||||
intermediate_size: int = 22528
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 64
|
||||
rope_theta: float = 8000000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
attetion_bias = args.attention_bias
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h, cache = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x, cache
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
out = out @ self.model.embed_tokens.weight.T
|
||||
out = out * self.model.args.logit_scale
|
||||
return out, cache
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
@ -23,29 +23,58 @@ class RMSNorm(nn.Module):
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def ln_norm(x, eps, weight=None, bias=None):
|
||||
"""
|
||||
Layer normalization for input tensor x.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input tensor.
|
||||
eps (float, optional): Small value to avoid division by zero.
|
||||
weight (np.ndarray, optional): Weight tensor for normalization.
|
||||
bias (np.ndarray, optional): Bias tensor for normalization.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized tensor.
|
||||
"""
|
||||
t = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Compute mean and variance along the last dimension
|
||||
means = mx.mean(x, axis=-1, keepdims=True)
|
||||
var = mx.var(x, axis=-1, keepdims=True)
|
||||
|
||||
# Normalize the input tensor
|
||||
x = (x - means) * mx.rsqrt(var + eps)
|
||||
x = x.astype(t)
|
||||
return weight * x + bias if weight is not None else x
|
||||
|
||||
# Apply weight and bias if provided
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
|
||||
def __init__(
|
||||
self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
if affine:
|
||||
self.bias = mx.zeros((dims,))
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
self.dims = dims
|
||||
self.affine = affine
|
||||
|
||||
if affine:
|
||||
self.weight = mx.ones((dims,))
|
||||
self.bias = mx.zeros((dims,)) if bias else None
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if "weight" in self:
|
||||
return ln_norm(x, self.eps, self.weight, self.bias)
|
||||
if self.affine:
|
||||
if self.bias is not None:
|
||||
return ln_norm(x, self.eps, self.weight, self.bias)
|
||||
else:
|
||||
return ln_norm(x, self.eps, self.weight)
|
||||
else:
|
||||
return ln_norm(x, self.eps)
|
||||
|
@ -254,7 +254,6 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
|
||||
|
||||
def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
|
||||
|
||||
from mlx_lm.models import starcoder2
|
||||
|
||||
args = starcoder2.ModelArgs(
|
||||
@ -276,6 +275,17 @@ class TestModels(unittest.TestCase):
|
||||
self.assertIn("lm_head.weight", sanitized_weights)
|
||||
self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
|
||||
|
||||
def test_cohere(self):
|
||||
from mlx_lm.models import cohere
|
||||
|
||||
args = cohere.ModelArgs(
|
||||
model_type="cohere",
|
||||
)
|
||||
model = cohere.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user