mlx-examples/llms/mlx_lm/models/openelm.py

230 lines
6.8 KiB
Python
Raw Normal View History

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values read by the OpenELM modules in this file.

    The list-valued fields are indexed by layer id, so every transformer
    layer can have its own head counts and feed-forward width.
    """

    # Architecture identifier; copied onto ``Model.model_type``.
    model_type: str
    # Size of each attention head.
    head_dim: int
    # Number of ``TransformerBlock`` layers to build.
    num_transformer_layers: int
    # Width of the token embeddings and of every block's input/output.
    model_dim: int
    # Number of entries in the token embedding table.
    vocab_size: int
    # Feed-forward hidden widths are rounded to a multiple of this value.
    ffn_dim_divisor: int
    # Query head count for each layer (indexed by layer id).
    num_query_heads: List
    # Key/value head count for each layer (indexed by layer id).
    num_kv_heads: List
    # Per-layer factor applied to ``model_dim`` to size the feed-forward block.
    ffn_multipliers: List
    # Flag for a gated (GLU) feed-forward layout.
    # NOTE(review): confirm how the MLP in this file consumes it.
    ffn_with_glu: bool = True
    # When true, queries and keys are RMS-normalised per head before RoPE.
    normalize_qk_projections: bool = True
    # When true, the token embedding matrix doubles as the output projection.
    share_input_output_layers: bool = True
    # Epsilon passed to every ``nn.RMSNorm``.
    rms_norm_eps: float = 1e-6
    # Base passed to ``nn.RoPE``.
    rope_freq_constant: float = 10_000
def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    Round ``v`` to a nearby multiple of ``divisor``.

    Adapted from the original TensorFlow repository, where it is used to keep
    every layer's channel count divisible by ``divisor``:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62

    Args:
        v: input value
        divisor: the result is a multiple of this value (default 8)
        min_value: lower bound for the result; falls back to ``divisor``
            when omitted
    Returns:
        The rounded value, never below ``min_value`` and never more than
        10% smaller than ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    # Round half up to the nearest multiple of ``divisor``.
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, rounded)
    # Rounding down must not shrink the value by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result
class Attention(nn.Module):
    """Self-attention layer with one fused QKV projection and rotary embeddings.

    The numbers of query heads and key/value heads are looked up per layer
    from ``args.num_query_heads`` and ``args.num_kv_heads``.
    """

    def __init__(self, args: ModelArgs, layer_id: int):
        """Build the projections, optional q/k norms and RoPE for ``layer_id``."""
        super().__init__()
        head_dim = args.head_dim
        model_dim = args.model_dim
        n_heads = args.num_query_heads[layer_id]
        n_kv_heads = args.num_kv_heads[layer_id]

        self.head_dim = head_dim
        self.layer_id = layer_id
        self.model_dim = model_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.scale = head_dim**-0.5

        # A single projection emits queries, keys and values back to back.
        fused_width = (n_heads + 2 * n_kv_heads) * head_dim
        self.qkv_proj = nn.Linear(model_dim, fused_width, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)

        self.normalize_qk_projections = args.normalize_qk_projections
        if self.normalize_qk_projections:
            self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
            self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Apply attention to ``x``.

        Args:
            x: 3-D input; the first two axes are kept in the output.
            mask: optional mask forwarded to the fused attention kernel.
            cache: optional per-layer cache. It is used as an object exposing
                ``offset`` and ``update_and_fetch(keys, values)``.
        Returns:
            The result of ``out_proj`` applied to the attention output.
        """
        batch, seq_len, _ = x.shape

        # Split the fused projection into heads and move the head axis forward.
        total_heads = self.n_heads + 2 * self.n_kv_heads
        fused = self.qkv_proj(x).reshape(batch, seq_len, total_heads, self.head_dim)
        fused = fused.transpose(0, 2, 1, 3)

        # Head layout along axis 1: [queries | keys | values].
        keys_start = self.n_heads
        values_start = keys_start + self.n_kv_heads
        queries, keys, values = mx.split(fused, [keys_start, values_start], axis=1)

        if self.normalize_qk_projections:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            # Rotate from the cached position, then append to the cache.
            position = cache.offset
            queries = self.rope(queries, offset=position)
            keys = self.rope(keys, offset=position)
            keys, values = cache.update_and_fetch(keys, values)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # Move heads back next to the feature axis and merge them.
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out_proj(merged)
class MLP(nn.Module):
    """Feed-forward block, gated (GLU) or plain depending on ``args.ffn_with_glu``.

    The hidden width is ``args.ffn_multipliers[layer_id] * args.model_dim``
    rounded with ``make_divisible`` to a multiple of ``args.ffn_dim_divisor``.
    """

    def __init__(self, args: ModelArgs, layer_id: int):
        """Create the two projections for the layer identified by ``layer_id``."""
        super().__init__()
        self.args = args
        dim = args.model_dim
        ffn_multiplier = args.ffn_multipliers[layer_id]
        intermediate_dim = int(
            make_divisible(
                ffn_multiplier * dim,
                divisor=args.ffn_dim_divisor,
            )
        )

        # Honour the config flag instead of always building the gated layout:
        # the gated form needs twice the width out of ``proj_1`` (gate + value),
        # the plain form needs only ``intermediate_dim``.
        self.ffn_with_glu = args.ffn_with_glu
        proj_1_width = 2 * intermediate_dim if self.ffn_with_glu else intermediate_dim

        self.proj_1 = nn.Linear(dim, proj_1_width, bias=False)
        self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        """Apply the feed-forward transformation to ``x``."""
        x = self.proj_1(x)
        if self.ffn_with_glu:
            # First half gates the second half.
            gate, x = mx.split(x, 2, axis=-1)
            return self.proj_2(nn.silu(gate) * x)
        # Plain FFN: activation between the two projections.
        return self.proj_2(nn.silu(x))
class TransformerBlock(nn.Module):
    """One pre-norm layer: attention and feed-forward, each with a residual add."""

    def __init__(self, args: ModelArgs, layer_id: int):
        """Build the attention, feed-forward and norm sub-modules for ``layer_id``."""
        super().__init__()
        width = args.model_dim
        eps = args.rms_norm_eps
        self.attn = Attention(args, layer_id=layer_id)
        self.ffn = MLP(args, layer_id=layer_id)
        self.ffn_norm = nn.RMSNorm(width, eps=eps)
        self.attn_norm = nn.RMSNorm(width, eps=eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Run the block on ``x``; ``mask`` and ``cache`` go to the attention layer."""
        # Attention on the normalised input, added back onto the input.
        h = x + self.attn(self.attn_norm(x), mask, cache)
        # Feed-forward on the normalised intermediate, added back onto it.
        return h + self.ffn(self.ffn_norm(h))
class OpenELMModel(nn.Module):
    """Token embeddings, a stack of ``TransformerBlock`` layers and a final norm."""

    def __init__(self, args: ModelArgs):
        """Build the embedding table, ``args.num_transformer_layers`` blocks and the norm."""
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_transformer_layers = args.num_transformer_layers
        assert self.vocab_size > 0
        self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
        self.layers = [
            TransformerBlock(args, layer_id=layer_id)
            for layer_id in range(self.num_transformer_layers)
        ]
        self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs``, run every layer and return the normalised hidden states.

        Args:
            inputs: token ids; axis 1 is treated as the sequence axis.
            cache: optional sequence with one cache object per layer (each is
                handed to the matching layer). ``None`` disables caching.
        Returns:
            The output of the final ``RMSNorm``.
        """
        h = self.token_embeddings(inputs)
        seq_len = h.shape[1]

        mask = None
        if seq_len > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
            mask = mask.astype(h.dtype)

            # When the cache already holds earlier tokens, the keys seen by
            # attention have length ``offset + seq_len``. A square mask cannot
            # broadcast against those scores, so prepend ``offset`` columns of
            # zeros: every cached position is visible to every new query.
            offset = cache[0].offset if cache and cache[0] is not None else 0
            if offset > 0:
                visible_past = mx.zeros((seq_len, offset), dtype=mask.dtype)
                mask = mx.concatenate([visible_past, mask], axis=1)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)
        return self.norm(h)
class Model(nn.Module):
    """OpenELM language model: the transformer stack plus an output projection.

    The output projection is either the token embedding matrix reused as a
    linear layer or a separate ``lm_head``, depending on
    ``args.share_input_output_layers``.
    """

    def __init__(self, args: ModelArgs):
        """Build the transformer and, when weights are not shared, ``lm_head``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = OpenELMModel(args)
        if not args.share_input_output_layers:
            self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the transformer on ``inputs`` and project to vocabulary size."""
        hidden = self.transformer(inputs, cache)
        if not self.args.share_input_output_layers:
            return self.lm_head(hidden)
        # Shared weights: reuse the embedding table as the output projection.
        return self.transformer.token_embeddings.as_linear(hidden)

    @property
    def layers(self):
        """The list of transformer blocks."""
        return self.transformer.layers

    @property
    def head_dim(self):
        """The configured attention head size."""
        return self.args.head_dim

    @property
    def n_kv_heads(self):
        """The configured key/value head counts (the per-layer list)."""
        return self.args.num_kv_heads