mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

* fix rotating kv cache for chat use case * reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat * nit in chat * fix tests * fix tests * fix tests * docs * chat command * comments + docs * Define meta_state on all Cache implementations * fixes + trim_prompt_cache api * fix default model --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
311 lines
9.9 KiB
Python
311 lines
9.9 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from functools import partial
|
|
from typing import Any, Optional
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .base import BaseModelArgs, create_attention_mask
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration of a Phi-3-small style model.

    Field order is significant for the generated ``__init__``: every field
    without a default is declared before the fields that have one.
    """

    # --- Core architecture (required) ---
    model_type: str
    hidden_size: int
    dense_attention_every_n_layers: int
    ff_intermediate_size: int
    gegelu_limit: float
    num_hidden_layers: int
    num_attention_heads: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int

    # --- muP multipliers (attention logits, embeddings, output width) ---
    mup_attn_multiplier: float = 1.0
    mup_use_scaling: bool = True
    mup_embedding_multiplier: float = 10.0
    mup_width_multiplier: float = 8.0

    # --- Rotary position embedding ---
    rope_embedding_base: float = 1_000_000
    rope_position_scale: float = 1.0

    # --- Block-sparse attention layout ---
    blocksparse_block_size: int = 64
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8
|
|
|
|
|
|
@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
    """Compiled GeGELU kernel.

    The gate branch is clipped from above at ``limit`` and the linear branch
    to ``[-limit, limit]``; entries that are infinite bypass the clipping and
    pass through unchanged. The gate is activated with the sigmoid form
    ``g * sigmoid(1.702 * g)`` and then scaled by ``linear + 1``.
    """
    gate_is_inf = mx.isinf(a_gelu)
    linear_is_inf = mx.isinf(a_linear)

    gate = mx.where(gate_is_inf, a_gelu, mx.clip(a_gelu, a_min=None, a_max=limit))
    linear = mx.where(
        linear_is_inf, a_linear, mx.clip(a_linear, a_min=-limit, a_max=limit)
    )

    activated_gate = gate * mx.sigmoid(1.702 * gate)
    return activated_gate * (linear + 1.0)
|
|
|
|
|
|
def gegelu(x, limit):
    """Apply GeGELU to ``x`` whose last axis interleaves gate and linear channels.

    Even-indexed channels of the last axis form the gate branch and
    odd-indexed channels the linear branch; both are handed to the compiled
    ``gegelu_impl`` together with the clipping ``limit``.
    """
    gate_channels = x[..., 0::2]
    linear_channels = x[..., 1::2]
    return gegelu_impl(gate_channels, linear_channels, limit)
|
|
|
|
|
|
class Attention(nn.Module):
    """Grouped-query self-attention for Phi-3-small.

    A layer is either dense (``mx.fast.scaled_dot_product_attention``) or
    block sparse. In the block-sparse case keys are grouped into blocks of
    ``blocksparse_block_size`` tokens, and a query block may attend only to
    the most recent ``blocksparse_num_local_blocks`` key blocks plus a
    per-head set of "vertical" key blocks selected by
    ``blocksparse_vert_stride``. The sparse pattern is currently realized as
    a dense additive mask.
    """

    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.n_q_per_kv = n_heads // n_kv_heads

        self.head_dim = head_dim = args.hidden_size // n_heads

        # One fused projection for all query, key and value heads.
        self.query_key_value = nn.Linear(
            dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
        )
        self.dense = nn.Linear(dim, dim)

        # With muP scaling the logits are divided by
        # head_dim / mup_attn_multiplier instead of sqrt(head_dim).
        if args.mup_use_scaling:
            norm_factor = head_dim / args.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(head_dim)
        self.scale = 1.0 / norm_factor

        self.rope = nn.RoPE(
            head_dim,
            traditional=False,
            base=args.rope_embedding_base,
            scale=args.rope_position_scale,
        )

        # Every n-th layer (counting from 1) is dense and the others are block
        # sparse, as the config field name says. For n == 2 this makes even
        # layer_idx sparse and odd layer_idx dense. A falsy n means "no dense
        # layers" rather than a division by zero.
        every_n = args.dense_attention_every_n_layers
        use_dense = bool(every_n) and (layer_idx + 1) % every_n == 0
        if not use_dense:
            self.block_sparse = True
            self.blocksparse_block_size = args.blocksparse_block_size
            if self.blocksparse_block_size not in (32, 64):
                raise ValueError(
                    f"Unsupported block size {self.blocksparse_block_size}"
                )
            self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
            self.blocksparse_vert_stride = args.blocksparse_vert_stride
        else:
            self.block_sparse = False

    def _block_sparse_mask(self, q_len, kv_len):
        """Build the block-sparse attention pattern.

        The queries are the last ``q_len`` positions of a ``kv_len``-long
        sequence, i.e. query ``i`` sits at absolute position
        ``kv_len - q_len + i``.

        Returns:
            ``(block_mask, dense_mask)``. ``block_mask`` has shape
            ``(n_kv_heads, n_q_per_kv, n_query_blocks, kv_blocks)`` and covers
            every block touched by the queries. ``dense_mask`` has shape
            ``(n_kv_heads, n_q_per_kv, q_len, kv_len)`` and is ``True`` where
            a query may attend to a key.
        """
        vert_stride = self.blocksparse_vert_stride
        local_blocks = self.blocksparse_num_local_blocks
        block_size = self.blocksparse_block_size
        n_heads = self.n_heads

        kv_blocks = (kv_len + block_size - 1) // block_size

        # Place the queries by absolute position so that a query token gets
        # the pattern of the block it actually lives in, even when kv_len (or
        # the cache offset) is not a multiple of block_size.
        q_start = kv_len - q_len
        q_first_block = q_start // block_size
        q_pos = mx.arange(q_first_block, kv_blocks)[None, :, None]
        k_pos = mx.arange(kv_blocks)[None, None]

        # Vertical stride: head h (0-based) keeps key block b when
        # (b + h + 1) % vert_stride == 0.
        mask_vert_strided = (
            mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
        ) % vert_stride
        mask_vert_strided = (mask_vert_strided == 0)[:, None, :]

        # Block-level causality AND (local window OR vertical stride).
        block_mask = (q_pos >= k_pos) & (
            (q_pos - k_pos < local_blocks) | mask_vert_strided
        )
        # Head h maps to (h // n_q_per_kv, h % n_q_per_kv), matching how the
        # queries are grouped in _block_sparse_attention.
        block_mask = block_mask.reshape(
            self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
        )
        dense_mask = mx.repeat(
            mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
        )

        # Row r of dense_mask is absolute position q_first_block * block_size + r.
        row_offset = q_start - q_first_block * block_size
        return block_mask, dense_mask[..., row_offset : row_offset + q_len, :kv_len]

    def _block_sparse_attention(self, queries, keys, values, scale, mask):
        """Attention restricted to the block-sparse pattern.

        Args:
            queries: laid out as ``(B, n_heads, L, head_dim)``; the reshape
                below relies on this.
            keys, values: presumably ``(B, n_kv_heads, S, head_dim)``. A
                singleton axis is inserted at position 2 so they broadcast
                over the query heads of each group.
            scale: multiplier applied to the queries before the product.
            mask: optional additive mask added to the logits.

        Returns:
            Attention output reshaped to ``(B, n_heads, L, head_dim)``.
        """
        queries = scale * queries
        B = queries.shape[0]
        L = queries.shape[2]
        queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
        keys = mx.expand_dims(keys, 2)
        values = mx.expand_dims(values, 2)

        # TODO get rid of dense mask if we have a fill value
        block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
        scores = queries @ mx.swapaxes(keys, -1, -2)
        # TODO, uncomment when faster
        # scores = mx.block_masked_mm(
        #     queries,
        #     mx.swapaxes(keys, -1, -2),
        #     mask_out=block_mask,
        #     block_size=self.blocksparse_block_size,
        # )

        if mask is not None:
            scores = scores + mask
        # Disallowed positions get -inf so they vanish in the softmax.
        scores = scores + mx.where(
            dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
        )
        scores = mx.softmax(scores, axis=-1, precise=True)

        output = scores @ values
        # TODO, uncomment when faster
        # output = mx.block_masked_mm(
        #     scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
        # )
        return mx.reshape(output, (B, self.n_heads, L, -1))

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run attention on ``x`` of shape ``(B, L, hidden_size)``.

        ``mask`` is an optional additive attention mask. ``cache``, when
        given, supplies the RoPE ``offset`` and accumulates keys/values via
        ``update_and_fetch``.
        """
        B, L, D = x.shape

        qkv = self.query_key_value(x)
        # Per KV group: n_q_per_kv query heads followed by one key head and
        # one value head.
        qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
        queries = qkv[..., :-2, :].flatten(-3, -2)
        keys = qkv[..., -2, :]
        values = qkv[..., -1, :]

        # (B, L, heads, head_dim) -> (B, heads, L, head_dim)
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        if self.block_sparse:
            output = self._block_sparse_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        else:
            output = mx.fast.scaled_dot_product_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.dense(output)
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward block: up-projection, GeGELU, down-projection."""

    def __init__(self, args):
        super().__init__()
        width = args.hidden_size
        inner = args.ff_intermediate_size
        self.gegelu_limit = args.gegelu_limit
        # The up-projection emits twice the inner width; gegelu splits it into
        # interleaved gate / linear halves of `inner` channels each.
        self.up_proj = nn.Linear(width, 2 * inner)
        self.down_proj = nn.Linear(inner, width)

    def __call__(self, x) -> mx.array:
        activated = gegelu(self.up_proj(x), self.gegelu_limit)
        return self.down_proj(activated)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: LayerNorm -> attention -> residual, then
    LayerNorm -> MLP -> residual."""

    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()
        eps = args.layer_norm_epsilon
        width = args.hidden_size

        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = width
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        hidden = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return hidden + self.mlp(self.post_attention_layernorm(hidden))
|
|
|
|
|
|
class Phi3Model(nn.Module):
    """Token embedding, a stack of TransformerBlocks and a final LayerNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.mup_embedding_multiplier = args.mup_embedding_multiplier
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=idx)
            for idx in range(args.num_hidden_layers)
        ]
        self.final_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_epsilon
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return final-normalized hidden states for the token ids ``inputs``.

        ``cache`` is either ``None`` or a per-layer sequence of caches that is
        zipped against ``self.layers``.
        """
        hidden = self.embed_tokens(inputs)
        # A zero/None multiplier disables the muP embedding scaling.
        if self.mup_embedding_multiplier:
            hidden = self.mup_embedding_multiplier * hidden

        mask = create_attention_mask(hidden, cache)

        per_layer_cache = [None] * len(self.layers) if cache is None else cache
        for block, layer_cache in zip(self.layers, per_layer_cache):
            hidden = block(hidden, mask, layer_cache)

        return self.final_layernorm(hidden)
|
|
|
|
|
|
class Model(nn.Module):
    """Phi-3-small language model producing vocabulary logits.

    The output head reuses the input embedding matrix (``as_linear``), and
    the logits are optionally divided by the muP width multiplier.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3Model(args)
        self.args = args
        self.mup_width_multiplier = args.mup_width_multiplier
        # Vocabulary ids whose logits are forced to -inf so they can never be
        # produced (presumably unused/dummy tokenizer slots, per the name).
        # The leading underscore is presumably meant to keep this array out
        # of the module's trainable/loaded parameters; confirm against
        # mlx.nn.Module.
        self._dummy_tokenizer_ids = mx.array(
            [100256, 100258, 100259, 100260, 100264, 100265]
            + list(range(100267, 100352))
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return logits with the vocabulary on the last axis."""
        out = self.model(inputs, cache)
        out = self.model.embed_tokens.as_linear(out)
        if self.mup_width_multiplier:
            out = out / self.mup_width_multiplier
        # The ids index the vocabulary (last) axis. Without the ellipsis the
        # assignment would target the leading (batch) axis instead.
        out[..., self._dummy_tokenizer_ids] = -float("inf")
        return out

    @property
    def layers(self):
        """The decoder blocks of the wrapped Phi3Model."""
        return self.model.layers

    def sanitize(self, weights):
        """Drop precomputed rotary inverse-frequency buffers from ``weights``.

        RoPE frequencies are recomputed by ``nn.RoPE``, so any key containing
        ``self_attn.rotary_emb.inv_freq`` is discarded.
        """
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }
|