mlx-examples/llms/mlx_lm/models/phi3small.py

# Copyright © 2023-2024 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dense_attention_every_n_layers: int
    ff_intermediate_size: int
    gegelu_limit: float
    num_hidden_layers: int
    num_attention_heads: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int
    mup_attn_multiplier: float = 1.0
    mup_use_scaling: bool = True
    mup_embedding_multiplier: float = 10.0
    mup_width_multiplier: float = 8.0
    rope_embedding_base: float = 1000000
    rope_position_scale: float = 1.0
    blocksparse_block_size: int = 64
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8
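

# GEGELU activation: the MLP's up-projection emits interleaved (gelu, linear) pairs.
# The gelu half is clipped from above at `limit` and the linear half to
# [-limit, limit] (infinities pass through unclipped); the GELU-gated half then
# multiplies (linear + 1), implementing the gegelu_limit behavior of phi-3-small.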
@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
    a_gelu = mx.where(
        mx.isinf(a_gelu),
        a_gelu,
        mx.clip(a_gelu, a_min=None, a_max=limit),
    )
    a_linear = mx.where(
        mx.isinf(a_linear),
        a_linear,
        mx.clip(a_linear, a_min=-limit, a_max=limit),
    )
    out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
    return out_gelu * (a_linear + 1.0)


def gegelu(x, limit):
    a_gelu, a_linear = x[..., ::2], x[..., 1::2]
    return gegelu_impl(a_gelu, a_linear, limit)
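

# Grouped-query attention with rotary embeddings. Depending on `layer_idx` and
# `dense_attention_every_n_layers`, a layer uses either dense scaled-dot-product
# attention or the block-sparse pattern built in `_block_sparse_mask` below.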
class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.n_q_per_kv = n_heads // n_kv_heads
        self.head_dim = head_dim = args.hidden_size // n_heads

        self.query_key_value = nn.Linear(
            dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
        )
        self.dense = nn.Linear(dim, dim)

        if args.mup_use_scaling:
            norm_factor = head_dim / args.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(head_dim)
        self.scale = 1.0 / norm_factor

        self.rope = nn.RoPE(
            head_dim,
            traditional=False,
            base=args.rope_embedding_base,
            scale=args.rope_position_scale,
        )

        if layer_idx % args.dense_attention_every_n_layers == 0:
            self.block_sparse = True
            self.blocksparse_block_size = args.blocksparse_block_size
            if self.blocksparse_block_size not in (32, 64):
                raise ValueError(
                    f"Unsupported block size {self.blocksparse_block_size}"
                )
            self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
            self.blocksparse_vert_stride = args.blocksparse_vert_stride
        else:
            self.block_sparse = False
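
    # Build the block-sparse attention pattern over (block_size x block_size) tiles:
    # a query block attends to the `blocksparse_num_local_blocks` most recent key
    # blocks plus a per-head vertically strided set of earlier blocks (every
    # `blocksparse_vert_stride`-th block), all under the causal constraint.
    # Returns the per-block mask and its dense, token-level expansion.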
    def _block_sparse_mask(self, q_len, kv_len):
        vert_stride = self.blocksparse_vert_stride
        local_blocks = self.blocksparse_num_local_blocks
        block_size = self.blocksparse_block_size
        n_heads = self.n_heads

        kv_blocks = (kv_len + block_size - 1) // block_size
        q_blocks = (q_len + block_size - 1) // block_size

        q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
        k_pos = mx.arange(kv_blocks)[None, None]
        mask_vert_strided = (
            mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
        ) % vert_stride
        mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
        block_mask = (q_pos >= k_pos) & (
            (q_pos - k_pos < local_blocks) | mask_vert_strided
        )
        block_mask = block_mask.reshape(
            self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
        )
        dense_mask = mx.repeat(
            mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
        )
        return block_mask, dense_mask[..., -q_len:, :kv_len]
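
    # Block-sparse attention is currently evaluated with dense matmuls plus an
    # additive -inf mask derived from the block pattern; the commented-out
    # mx.block_masked_mm calls are the intended sparse path once it is faster.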
    def _block_sparse_attention(self, queries, keys, values, scale, mask):
        queries = scale * queries
        B = queries.shape[0]
        L = queries.shape[2]
        queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
        keys = mx.expand_dims(keys, 2)
        values = mx.expand_dims(values, 2)

        # TODO get rid of dense mask if we have a fill value
        block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])

        scores = queries @ mx.swapaxes(keys, -1, -2)
        # TODO, uncomment when faster
        # scores = mx.block_masked_mm(
        #     queries,
        #     mx.swapaxes(keys, -1, -2),
        #     mask_out=block_mask,
        #     block_size=self.blocksparse_block_size,
        # )
        if mask is not None:
            scores = scores + mask
        scores = scores + mx.where(
            dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
        )
        scores = mx.softmax(scores, axis=-1, precise=True)

        output = scores @ values
        # TODO, uncomment when faster
        # output = mx.block_masked_mm(
        #     scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
        # )
        return mx.reshape(output, (B, self.n_heads, L, -1))
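
    # The fused QKV projection packs, per KV head, `n_q_per_kv` query heads
    # followed by one key head and one value head; the reshape/split below
    # recovers queries, keys, and values from that layout.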
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.query_key_value(x)
        qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
        queries = qkv[..., :-2, :].flatten(-3, -2)
        keys = qkv[..., -2, :]
        values = qkv[..., -1, :]

        # Prepare the queries, keys and values for the attention computation
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        if self.block_sparse:
            output = self._block_sparse_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        else:
            output = mx.fast.scaled_dot_product_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.dense(output)
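

# Feed-forward block: a single up-projection to 2 * ff_intermediate_size provides
# both halves of the gegelu gate, followed by a down-projection back to hidden_size.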
class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        dim = args.hidden_size
        hidden_dim = args.ff_intermediate_size
        self.gegelu_limit = args.gegelu_limit
        self.up_proj = nn.Linear(dim, 2 * hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)

    def __call__(self, x) -> mx.array:
        x = self.up_proj(x)
        return self.down_proj(gegelu(x, self.gegelu_limit))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_epsilon
        )
        self.post_attention_layernorm = nn.LayerNorm(
            args.hidden_size,
            eps=args.layer_norm_epsilon,
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out
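

# Core decoder stack: token embeddings (optionally scaled by mup_embedding_multiplier),
# the transformer layers, and a final LayerNorm.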
class Phi3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.mup_embedding_multiplier = args.mup_embedding_multiplier
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=l)
            for l in range(args.num_hidden_layers)
        ]
        self.final_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_epsilon
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        if self.mup_embedding_multiplier:
            h = self.mup_embedding_multiplier * h

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.final_layernorm(h)
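

# Top-level wrapper: ties the output projection to the input embeddings via
# `embed_tokens.as_linear`, rescales logits by 1 / mup_width_multiplier, and
# masks out unused ("dummy") tokenizer ids so they are never sampled.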
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3Model(args)
        self.args = args
        self.mup_width_multiplier = args.mup_width_multiplier
        self._dummy_tokenizer_ids = mx.array(
            [100256, 100258, 100259, 100260, 100264, 100265]
            + list(range(100267, 100352))
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.embed_tokens.as_linear(out)
        if self.mup_width_multiplier:
            out = out / self.mup_width_multiplier
        # The dummy ids index the vocabulary, so mask along the last (vocab) axis
        out[..., self._dummy_tokenizer_ids] = -float("inf")
        return out

    @property
    def layers(self):
        return self.model.layers

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }