mlx-examples/llms/mlx_lm/models/cohere.py
Awni Hannun c386dd5f5a
Fix for cohere plus (#650)
* fix for cohere plus

* version bump
2024-04-05 14:11:24 -07:00

195 lines
5.9 KiB
Python

from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyperparameters for the Cohere model defined in this module.

    Every field except ``model_type`` has a default value.
    """

    # Identifier stored on ``Model.model_type``.
    model_type: str
    # Width of the residual stream; also the embedding dimension.
    hidden_size: int = 8_192
    # Number of ``TransformerBlock`` layers in ``CohereModel``.
    num_hidden_layers: int = 40
    # Hidden width of the gated MLP.
    intermediate_size: int = 22_528
    # Query heads; head_dim is hidden_size // num_attention_heads.
    num_attention_heads: int = 64
    # Heads used for the key/value projections.
    num_key_value_heads: int = 64
    # Base frequency handed to nn.RoPE.
    rope_theta: float = 8_000_000.0
    # Rows of the (tied) embedding matrix.
    vocab_size: int = 256_000
    # Epsilon shared by every LayerNorm in the model.
    layer_norm_eps: float = 1e-5
    # Multiplier applied to the output logits.
    logit_scale: float = 0.0625
    # Whether the q/k/v/o projections carry a bias term.
    attention_bias: bool = False
    # Whether the block / final LayerNorms carry a bias term.
    layer_norm_bias: bool = False
    # Enables per-head LayerNorm on queries and keys inside Attention.
    use_qk_norm: bool = False
class LayerNorm2D(nn.Module):
    """Scale-only layer norm whose weight is a ``(d1, d2)`` matrix.

    Normalisation runs over the last axis of the input with no built-in
    affine parameters; the result is then multiplied elementwise by
    ``self.weight``, which broadcasts over the trailing two axes.
    """

    def __init__(self, d1, d2, eps):
        super().__init__()
        weight_shape = (d1, d2)
        # Initialised to zeros here; presumably always overwritten when the
        # checkpoint weights are loaded (TODO confirm for all checkpoints).
        self.weight = mx.zeros(weight_shape)
        self.eps = eps

    def __call__(self, x):
        normalized = mx.fast.layer_norm(x, None, None, self.eps)
        return self.weight * normalized
class Attention(nn.Module):
    """Multi-head self-attention with RoPE and optional per-head QK LayerNorm.

    Key/value projections use ``num_key_value_heads`` heads, which may differ
    from the number of query heads. No explicit head repetition is done here;
    presumably ``mx.fast.scaled_dot_product_attention`` handles the mismatch
    (TODO confirm against the MLX docs for the installed version).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // args.num_attention_heads
        self.scale = head_dim**-0.5

        attention_bias = args.attention_bias
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.use_qk_norm = args.use_qk_norm
        if self.use_qk_norm:
            # One (heads, head_dim) weight per projection, applied before RoPE.
            self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
            self.k_norm = LayerNorm2D(
                self.n_kv_heads, head_dim, eps=args.layer_norm_eps
            )

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: Input of shape ``(B, L, hidden_size)``.
            mask: Optional additive mask forwarded to the attention kernel.
            cache: Optional ``(keys, values)`` from a previous call, each of
                shape ``(B, n_kv_heads, S, head_dim)``. New keys/values are
                appended along axis 2 and RoPE is offset by ``S``.

        Returns:
            ``(output, (keys, values))``. ``output`` has shape
            ``(B, L, hidden_size)``. ``keys``/``values`` are the updated
            cache tensors, including the ``L`` new positions.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = queries.reshape(B, L, self.n_heads, -1)
        keys = keys.reshape(B, L, self.n_kv_heads, -1)

        # QK norm runs while heads are still on axis 2, matching the
        # (heads, head_dim) layout of the LayerNorm2D weights.
        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        # -> (B, heads, L, head_dim)
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            # New tokens sit after the cached ones, so rotate by that offset.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
class MLP(nn.Module):
    """Gated feed-forward network computing ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x):
        activated_gate = nn.silu(self.gate_proj(x))
        expanded = self.up_proj(x)
        return self.down_proj(activated_gate * expanded)
class TransformerBlock(nn.Module):
    """Decoder block with parallel attention and MLP branches.

    Computes ``x + attn(norm(x)) + mlp(norm(x))``. A single LayerNorm output
    feeds both branches, and their results are summed with the residual.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Run the block.

        Args:
            x: Hidden states, presumably ``(B, L, hidden_size)`` (TODO confirm).
            mask: Optional additive attention mask, forwarded to attention.
            cache: Optional ``(keys, values)`` cache for this layer.

        Returns:
            ``(hidden, cache)``: the block output and this layer's updated
            ``(keys, values)`` cache from the attention branch.
        """
        h = self.input_layernorm(x)
        attn_h, cache = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x, cache
class CohereModel(nn.Module):
    """Decoder stack: token embedding, ``TransformerBlock`` layers, final LayerNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs`` and run them through every layer.

        Args:
            inputs: Token ids, presumably ``(B, L)`` (TODO confirm).
            cache: ``None``, or a list with one ``(keys, values)`` entry per
                layer, as returned by a previous call. When given, the list is
                updated in place and also returned.

        Returns:
            ``(hidden, cache)``: the normalised final hidden states and the
            per-layer cache list.
        """
        h = self.embed_tokens(inputs)
        num_new = h.shape[1]

        mask = None
        if num_new > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(num_new)

            # Attention concatenates cached keys along axis 2, so the score
            # matrix is (num_new, offset + num_new). A plain (num_new, num_new)
            # causal mask would not match it. Every new query may see all
            # cached positions, so prepend an all-zero (unmasked) block.
            offset = 0
            if cache and cache[0] is not None:
                offset = cache[0][0].shape[2]
            if offset > 0:
                visible = mx.zeros((num_new, offset), dtype=mask.dtype)
                mask = mx.concatenate([visible, mask], axis=1)

            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache
class Model(nn.Module):
    """Cohere causal LM: decoder stack plus a tied, scaled output projection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = CohereModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return ``(logits, cache)`` for ``inputs``.

        The logits come from multiplying the final hidden states by the
        transposed input-embedding matrix, then scaling by
        ``args.logit_scale``.
        """
        hidden, cache = self.model(inputs, cache)
        # No separate LM head: the embedding weights double as the output layer.
        embedding_matrix = self.model.embed_tokens.weight
        logits = hidden @ embedding_matrix.T
        scaled_logits = logits * self.model.args.logit_scale
        return scaled_logits, cache

    @property
    def layers(self):
        """The list of decoder blocks held by the inner ``CohereModel``."""
        return self.model.layers