From ee60e2a9d585788da30efa90326be9d2f1bceb97 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 8 May 2024 08:18:13 -0700 Subject: [PATCH] Kv cache (#643) * in place kv_cache * fix * fix kv cache size * partially fix kv cache dtype * step kv cache * multiple of step size * more teests + kv cache * more kv cache * udpate all models to use kv cache --- llms/mlx_lm/models/base.py | 31 +++++++++ llms/mlx_lm/models/cohere.py | 33 +++++---- llms/mlx_lm/models/dbrx.py | 36 ++++++---- llms/mlx_lm/models/gemma.py | 33 +++++---- llms/mlx_lm/models/llama.py | 33 +++++---- llms/mlx_lm/models/minicpm.py | 34 +++++---- llms/mlx_lm/models/mixtral.py | 33 +++++---- llms/mlx_lm/models/olmo.py | 29 +++++--- llms/mlx_lm/models/openelm.py | 41 +++++------ llms/mlx_lm/models/phi.py | 55 ++++++++------- llms/mlx_lm/models/phi3.py | 33 +++++---- llms/mlx_lm/models/phixtral.py | 43 +++++++----- llms/mlx_lm/models/plamo.py | 77 ++++++++++----------- llms/mlx_lm/models/qwen.py | 55 ++++++++------- llms/mlx_lm/models/qwen2.py | 32 +++++---- llms/mlx_lm/models/qwen2_moe.py | 32 +++++---- llms/mlx_lm/models/stablelm.py | 34 +++++---- llms/mlx_lm/models/starcoder2.py | 32 +++++---- llms/mlx_lm/tokenizer_utils.py | 3 +- llms/mlx_lm/utils.py | 16 +++-- llms/mlx_lm/version.py | 2 +- llms/tests/test_models.py | 115 ++++++++++++++++++++++++++++++- 22 files changed, 534 insertions(+), 298 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index d1ea0b2c..15002f8f 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -1,6 +1,37 @@ import inspect from dataclasses import dataclass +import mlx.core as mx + + +class KVCache: + + def __init__(self, head_dim, n_kv_heads): + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if prev % self.step == 0: + n_steps = (self.step + keys.shape[2] - 1) // self.step + shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim) + new_k = mx.zeros(shape, keys.dtype) + new_v = mx.zeros(shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + @dataclass class BaseModelArgs: diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index dae61760..621a85a2 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -84,11 +84,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -98,7 +96,7 @@ class Attention(nn.Module): ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class 
MLP(nn.Module): @@ -132,9 +130,9 @@ class TransformerBlock(nn.Module): cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: h = self.input_layernorm(x) - attn_h, cache = self.self_attn(h, mask, cache) + attn_h = self.self_attn(h, mask, cache) ff_h = self.mlp(h) - return attn_h + ff_h + x, cache + return attn_h + ff_h + x class CohereModel(nn.Module): @@ -167,10 +165,10 @@ class CohereModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -178,17 +176,26 @@ class Model(nn.Module): super().__init__() self.model_type = args.model_type self.model = CohereModel(args) + self.args = args def __call__( self, inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) + out = self.model(inputs, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale - return out, cache + return out @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index b0ea0efa..dc310ca4 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -65,11 +65,9 @@ class Attention(nn.Module): ) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -78,7 +76,7 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output), (keys, values) + return self.out_proj(output) class NormAttnNorm(nn.Module): @@ -94,9 +92,9 @@ class NormAttnNorm(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - h, cache = self.attn(self.norm_1(x), mask=mask, cache=cache) + h = self.attn(self.norm_1(x), mask=mask, cache=cache) x = h + x - return x, self.norm_2(x), cache + return x, self.norm_2(x) class MLP(nn.Module): @@ -181,9 +179,9 @@ class DecoderLayer(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, h, cache = self.norm_attn_norm(x, mask, cache) + r, h = self.norm_attn_norm(x, mask, cache) out = self.ffn(h) + r - return out, cache + return out class DBRX(nn.Module): @@ -210,10 +208,10 @@ class DBRX(nn.Module): if cache is None: cache = [None] * len(self.blocks) - for e, layer in enumerate(self.blocks): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.blocks, cache): + h = layer(h, mask, c) - return self.norm_f(h), cache + return self.norm_f(h) class Model(nn.Module): @@ -229,8 +227,8 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.transformer(inputs, cache) - return self.lm_head(out), cache + out = self.transformer(inputs, cache) + return self.lm_head(out) @property def layers(self): @@ 
-253,3 +251,11 @@ class Model(nn.Module): experts = [(s, sv.T) for s, sv in experts] new_weights.update(experts) return new_weights + + @property + def head_dim(self): + return self.args.d_model // self.args.n_heads + + @property + def n_kv_heads(self): + return self.args.attn_config["kv_n_heads"] diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index ebd8f5e7..e48f1909 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -70,11 +70,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -84,7 +82,7 @@ class Attention(nn.Module): ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MLP(nn.Module): @@ -115,11 +113,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class GemmaModel(nn.Module): @@ -151,10 +149,10 @@ class GemmaModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -162,16 +160,25 @@ class Model(nn.Module): super().__init__() self.model_type = args.model_type self.model = GemmaModel(args) + self.args = args def __call__( self, inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) + out = self.model(inputs, cache) out = self.model.embed_tokens.as_linear(out) - return out, cache + return out @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.head_dim + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index cbb5f2cb..ada05d0f 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -78,11 +78,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -91,7 +89,7 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return 
self.o_proj(output) class MLP(nn.Module): @@ -124,11 +122,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out def create_additive_causal_mask(N: int, offset: int = 0): @@ -168,10 +166,10 @@ class LlamaModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -180,14 +178,15 @@ class Model(nn.Module): self.model_type = args.model_type self.model = LlamaModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.args = args def __call__( self, inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + out = self.model(inputs, cache) + return self.lm_head(out) def sanitize(self, weights): # Remove unused precomputed rotary freqs @@ -198,3 +197,11 @@ class Model(nn.Module): @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index d3119f71..dbfe4186 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -19,7 +19,6 @@ class ModelArgs(BaseModelArgs): rms_norm_eps: float vocab_size: int num_key_value_heads: int - max_position_embeddings: int scale_depth: float scale_emb: float rope_theta: float = 1000000.0 @@ -47,7 +46,6 @@ class Attention(nn.Module): self.hidden_size = args.hidden_size self.num_heads = n_heads = args.num_attention_heads self.rope_theta = args.rope_theta - self.max_position_embeddings = args.max_position_embeddings self.head_dim = head_dim = args.hidden_size // n_heads self.scale = head_dim**-0.5 @@ -98,11 +96,9 @@ class Attention(nn.Module): ) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -113,7 +109,7 @@ class Attention(nn.Module): attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(attn_output), (keys, values) + return self.o_proj(attn_output) class DecoderLayer(nn.Module): @@ -139,11 +135,11 @@ class DecoderLayer(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) r = self.mlp(self.post_attention_layernorm(h)) out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) - return out, cache + return out class MiniCPMModel(nn.Module): @@ -172,10 +168,10 @@ 
class MiniCPMModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -193,14 +189,14 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) + out = self.model(inputs, cache) if not self.args.tie_word_embeddings: out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) else: out = out @ self.model.embed_tokens.weight.T - return out, cache + return out def sanitize(self, weights): if "lm_head.weight" not in weights: @@ -210,3 +206,11 @@ class Model(nn.Module): @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index d11a7507..7bf67638 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -77,11 +77,9 @@ class MixtralAttention(nn.Module): ) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -90,7 +88,7 @@ class MixtralAttention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MixtralBLockSparseTop2MLP(nn.Module): @@ -180,11 +178,11 @@ class MixtralDecoderLayer(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.block_sparse_moe(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class MixtralModel(nn.Module): @@ -215,10 +213,10 @@ class MixtralModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -227,15 +225,24 @@ class Model(nn.Module): self.model_type = args.model_type self.model = MixtralModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.args = args def __call__( self, inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + out = self.model(inputs, cache) + return self.lm_head(out) @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index b84b2a38..120ea9b9 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ 
-78,11 +78,9 @@ class TransformerBlock(nn.Module): values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -92,7 +90,7 @@ class TransformerBlock(nn.Module): scores += mask scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.attn_out(output), (keys, values) + return self.attn_out(output) def __call__( self, @@ -100,13 +98,13 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.attend(self.att_norm(x), mask, cache) + r = self.attend(self.att_norm(x), mask, cache) h = x + r x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1) out = h + self.ff_out(nn.silu(x2) * x1) - return out, cache + return out class Transformer(nn.Module): @@ -136,15 +134,15 @@ class Transformer(nn.Module): if cache is None: cache = [None] * len(self.blocks) - for e, block in enumerate(self.blocks): - h, cache[e] = block(h, mask, cache[e]) + for block, c in zip(self.blocks, cache): + h = block(h, mask, c) h = self.norm(h) if self.weight_tying: return self.wte.as_linear(h), cache - return self.ff_out(h), cache + return self.ff_out(h) class OlmoModel(nn.Module): @@ -165,6 +163,7 @@ class Model(nn.Module): super().__init__() self.model_type = args.model_type self.model = OlmoModel(args) + self.args = args def __call__( self, @@ -176,3 +175,11 @@ class Model(nn.Module): @property def layers(self): return self.model.transformer.blocks + + @property + def head_dim(self): + return self.args.d_model // self.args.n_heads + + @property + def n_kv_heads(self): + return self.args.n_heads diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 49e1c5d1..3fbdc58c 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -22,8 +22,7 @@ class ModelArgs(BaseModelArgs): normalize_qk_projections: bool = True share_input_output_layers: bool = True rms_norm_eps: float = 1e-6 - rope_theta: float = 10000 - rope_traditional: bool = False + rope_freq_constant: float = 10000 def make_divisible( @@ -73,9 +72,7 @@ class Attention(nn.Module): self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) - self.rope = nn.RoPE( - head_dim, traditional=args.rope_traditional, base=args.rope_theta - ) + self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant) def __call__( self, @@ -87,12 +84,10 @@ class Attention(nn.Module): qkv = self.qkv_proj(x) - # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h] -> [B, (q_h + k_h + v_h), S, h] qkv = qkv.reshape( B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim ).transpose(0, 2, 1, 3) - # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h] queries, keys, values = mx.split( qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1 ) @@ -103,11 +98,9 @@ class Attention(nn.Module): keys = self.k_norm(keys) if cache is not None: - key_cache, 
value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -118,7 +111,7 @@ class Attention(nn.Module): output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output), (keys, values) + return self.out_proj(output) class MLP(nn.Module): @@ -159,11 +152,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.attn(self.attn_norm(x), mask, cache) + r = self.attn(self.attn_norm(x), mask, cache) h = x + r r = self.ffn(self.ffn_norm(h)) out = h + r - return out, cache + return out class OpenELMModel(nn.Module): @@ -195,10 +188,10 @@ class OpenELMModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -215,14 +208,22 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.transformer(inputs, cache) + out = self.transformer(inputs, cache) if self.args.share_input_output_layers: out = self.transformer.token_embeddings.as_linear(out) else: out = self.lm_head(out) - return out, cache + return out @property def layers(self): return self.transformer.layers + + @property + def head_dim(self): + return self.args.head_dim + + @property + def n_kv_heads(self): + return self.args.num_kv_heads diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 91b97023..8feaa23a 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -75,35 +75,29 @@ class PhiAttention(nn.Module): queries = queries.reshape( B, L, - n_kv_heads, - n_heads // n_kv_heads, + n_heads, -1, - ).moveaxis(1, 3) - keys = keys.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3) - values = values.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3) + ).moveaxis(1, 2) + keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) + values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) # Add RoPE to the queries and keys and combine them with the cache if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[-2]) - keys = self.rope(keys, offset=key_cache.shape[-2]) - keys = mx.concatenate([key_cache, keys], axis=-2) - values = mx.concatenate([value_cache, values], axis=-2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) - queries = queries.astype(mx.float32) - - # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.swapaxes(-1, -2) - if mask is not None: - scores = scores + mask - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - output = (scores @ values).moveaxis(3, 1).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries.astype(mx.float32), keys, values, scale=scale, mask=mask + ).astype(values.dtype) - return self.dense(output), (keys, 
values) + output = output.moveaxis(2, 1).reshape(B, L, -1) + + return self.dense(output) class PhiMLP(nn.Module): @@ -128,9 +122,9 @@ class PhiDecoderLayer(nn.Module): def __call__(self, x, mask, cache): h = self.input_layernorm(x) - attn_h, cache = self.self_attn(h, mask, cache) + attn_h = self.self_attn(h, mask, cache) ff_h = self.mlp(h) - return attn_h + ff_h + x, cache + return attn_h + ff_h + x class PhiModel(nn.Module): @@ -152,9 +146,9 @@ class PhiModel(nn.Module): mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(x.dtype) - for e, layer in enumerate(self.layers): - x, cache[e] = layer(x, mask, cache[e]) - return self.final_layernorm(x), cache + for layer, c in zip(self.layers, cache): + x = layer(x, mask, c) + return self.final_layernorm(x) class Model(nn.Module): @@ -163,15 +157,24 @@ class Model(nn.Module): self.model_type = config.model_type self.model = PhiModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + self.args = config def __call__( self, x: mx.array, cache: mx.array = None, ) -> Tuple[mx.array, mx.array]: - y, cache = self.model(x, cache) - return self.lm_head(y), cache + y = self.model(x, cache) + return self.lm_head(y) @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 13d758e8..e7d38614 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -81,11 +81,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -94,7 +92,7 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MLP(nn.Module): @@ -128,11 +126,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class Phi3Model(nn.Module): @@ -163,10 +161,10 @@ class Phi3Model(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -175,15 +173,24 @@ class Model(nn.Module): self.model_type = args.model_type self.model = Phi3Model(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.args = args def __call__( self, inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + 
out = self.model(inputs, cache) + return self.lm_head(out) @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index fa4a24a4..7413e3cd 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -11,7 +11,6 @@ import numpy as np @dataclass class ModelArgs: model_type: str - max_sequence_length: int = 2048 num_vocab: int = 51200 model_dim: int = 2560 num_heads: int = 32 @@ -56,11 +55,9 @@ class RoPEAttention(nn.Module): # Add RoPE to the queries and keys and combine them with the cache if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -69,14 +66,13 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries.astype(mx.float32), keys, values, scale=scale, mask=mask + ).astype(values.dtype) + output = output.moveaxis(2, 1).reshape(B, L, -1) - return self.out_proj(values_hat), (keys, values) + return self.out_proj(output) class MLP(nn.Module): @@ -144,9 +140,9 @@ class ParallelBlock(nn.Module): def __call__(self, x, mask, cache): h = self.ln(x) - attn_h, cache = self.mixer(h, mask, cache) + attn_h = self.mixer(h, mask, cache) ff_h = self.moe(h) - return attn_h + ff_h + x, cache + return attn_h + ff_h + x class TransformerDecoder(nn.Module): @@ -160,9 +156,9 @@ class TransformerDecoder(nn.Module): if cache is None: cache = [None] * len(self.h) - for e, layer in enumerate(self.h): - x, cache[e] = layer(x, mask, cache[e]) - return x, cache + for layer, c in zip(self.h, cache): + x = layer(x, mask, c) + return x class Embd(nn.Module): @@ -190,6 +186,7 @@ class Model(nn.Module): self.model_type = config.model_type self.transformer = TransformerDecoder(config) self.lm_head = OutputHead(config) + self.args = config def __call__( self, @@ -202,9 +199,17 @@ class Model(nn.Module): mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(x.dtype) - y, cache = self.transformer(x, mask, cache) - return self.lm_head(y), cache + y = self.transformer(x, mask, cache) + return self.lm_head(y) @property def layers(self): return self.transformer.h + + @property + def head_dim(self): + return self.args.model_dim // self.args.num_heads + + @property + def n_kv_heads(self): + return self.args.num_heads diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 53c1252c..2d0ddaed 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -64,44 +64,38 @@ class Attention(nn.Module): ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: bsz, q_len, _ = hidden_states.shape - query_states = 
self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) # Prepare the queries, keys and values for the attention computation - query_states = query_states.reshape( - bsz, q_len, self.q_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - key_states = key_states.reshape( - bsz, q_len, self.k_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - value_states = value_states.reshape( - bsz, q_len, self.v_num_heads, self.v_dim - ).transpose(0, 2, 1, 3) - - # expand shared kv - assert self.k_num_heads == self.v_num_heads - - kv_seq_len = 0 - if cache is not None: - kv_seq_len += cache[0].shape[-2] - query_states = self.rotary_emb(query_states, offset=kv_seq_len) - key_states = self.rotary_emb(key_states, offset=kv_seq_len) + queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose( + 0, 2, 1, 3 + ) + keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose( + 0, 2, 1, 3 + ) if cache is not None: - # reuse k, v, self_attention - key_states = mx.concatenate([cache[0], key_states], axis=2) - value_states = mx.concatenate([cache[1], value_states], axis=2) + queries = self.rotary_emb(queries, offset=cache.offset) + keys = self.rotary_emb(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rotary_emb(queries) + keys = self.rotary_emb(keys) output = mx.fast.scaled_dot_product_attention( - query_states, - key_states, - value_states, + queries, + keys, + values, scale=self.scale, mask=attention_mask, ) output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) - return self.o_proj(output), (key_states, value_states) + return self.o_proj(output) class MLP(nn.Module): @@ -139,7 +133,7 @@ class PlamoDecoderLayer(nn.Module): hidden_states = self.norm(hidden_states) # Self Attention - hidden_states_sa, cache = self.self_attn( + hidden_states_sa = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, cache=cache, @@ -149,7 +143,7 @@ class PlamoDecoderLayer(nn.Module): hidden_states_mlp = self.mlp(hidden_states) hidden_states = residual + hidden_states_sa + hidden_states_mlp - return hidden_states, cache + return hidden_states class PlamoDecoder(nn.Module): @@ -185,14 +179,10 @@ class PlamoModel(nn.Module): if cache is None: cache = [None for _ in range(len(self.layers.layers))] - for e, layer in enumerate(self.layers.layers): - h, c = layer(h, mask, cache[e]) - if cache is not None: - cache[e] = c - else: - cache.append(c) + for layer, c in zip(self.layers.layers, cache): + h = layer(h, mask, cache=c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -203,15 +193,24 @@ class Model(nn.Module): self.lm_head: nn.Module = nn.Linear( args.hidden_size, args.vocab_size, bias=False ) + self.args = args def __call__( self, inputs: mx.array, cache: Optional[List[Tuple[mx.array, mx.array]]] = None, ) -> Tuple[mx.array, mx.array]: - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + out = self.model(inputs, cache) + return self.lm_head(out) @property def layers(self): return self.model.layers.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_attention_heads // self.args.n_shared_head 
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index a4e82dd2..44b6dfd3 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -51,30 +51,24 @@ class Attention(nn.Module): B, L, _ = q.shape - q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - k_cache, v_cache = cache - q = self.rotary_emb(q, offset=k_cache.shape[2]) - k = self.rotary_emb(k, offset=k_cache.shape[2]) - k = mx.concatenate([k_cache, k], axis=2) - v = mx.concatenate([v_cache, v], axis=2) - + queries = self.rotary_emb(queries, offset=cache.offset) + keys = self.rotary_emb(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: - q = self.rotary_emb(q) - k = self.rotary_emb(k) + queries = self.rotary_emb(queries) + keys = self.rotary_emb(keys) - scores = (q * self.scale) @ k.transpose(0, 1, 3, 2) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.c_proj(v_hat), (k, v) + return self.c_proj(output) class MLP(nn.Module): @@ -109,13 +103,13 @@ class TransformerBlock(nn.Module): def __call__(self, x, mask=None, cache=None): residual = x x = self.ln_1(x) - x, cache = self.attn(x, mask=mask, cache=cache) + x = self.attn(x, mask=mask, cache=cache) residual = x + residual x = self.ln_2(residual) x = self.mlp(x) x = x + residual - return x, cache + return x class QwenModel(nn.Module): @@ -137,10 +131,10 @@ class QwenModel(nn.Module): if cache is None: cache = [None] * len(self.h) - for e, layer in enumerate(self.h): - x, cache[e] = layer(x, mask, cache[e]) + for layer, c in zip(self.h, cache): + x = layer(x, mask, c) - return self.ln_f(x), cache + return self.ln_f(x) class Model(nn.Module): @@ -151,6 +145,7 @@ class Model(nn.Module): self.lm_head = nn.Linear( config.hidden_size, config.vocab_size, bias=not config.no_bias ) + self.args = config def __call__( self, @@ -158,9 +153,17 @@ class Model(nn.Module): mask: mx.array = None, cache: mx.array = None, ) -> Tuple[mx.array, mx.array]: - y, cache = self.transformer(x, mask, cache) - return self.lm_head(y), cache + y = self.transformer(x, mask, cache) + return self.lm_head(y) @property def layers(self): return self.transformer.h + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_attention_heads diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index d95893f9..b928de09 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -79,11 +79,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = 
mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -92,7 +90,7 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MLP(nn.Module): @@ -125,11 +123,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class Qwen2Model(nn.Module): @@ -160,10 +158,10 @@ class Qwen2Model(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -180,12 +178,12 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) + out = self.model(inputs, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: out = self.lm_head(out) - return out, cache + return out def sanitize(self, weights): if self.args.tie_word_embeddings: @@ -198,3 +196,11 @@ class Model(nn.Module): @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index abe9452c..ea8ab802 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -78,11 +78,9 @@ class Attention(nn.Module): values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -91,7 +89,7 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class Qwen2MoeMLP(nn.Module): @@ -187,11 +185,11 @@ class Qwen2MoeDecoderLayer(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class Qwen2MoeModel(nn.Module): @@ -222,10 +220,10 @@ class Qwen2MoeModel(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, 
cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -241,8 +239,8 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + out = self.model(inputs, cache) + return self.lm_head(out) def sanitize(self, weights): if self.args.tie_word_embeddings and "lm_head.weight" not in weights: @@ -255,3 +253,11 @@ class Model(nn.Module): @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 47af5295..30e3a332 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -107,11 +107,9 @@ class Attention(nn.Module): # Add RoPE to the queries and keys and combine them with the cache if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -125,7 +123,7 @@ class Attention(nn.Module): queries, keys, values, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MLP(nn.Module): @@ -157,7 +155,7 @@ class DecoderLayer(nn.Module): def __call__(self, x, mask, cache): h = self.input_layernorm(x) - r, cache = self.self_attn(h, mask, cache) + r = self.self_attn(h, mask, cache) if self.use_parallel_residual: out = x + r + self.mlp(h) @@ -165,7 +163,7 @@ class DecoderLayer(nn.Module): h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class StableLM(nn.Module): @@ -180,9 +178,10 @@ class StableLM(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - x, cache[e] = layer(x, mask, cache[e]) - return self.norm(x), cache + for layer, c in zip(self.layers, cache): + x = layer(x, mask, cache=c) + + return self.norm(x) class Model(nn.Module): @@ -191,6 +190,7 @@ class Model(nn.Module): self.model_type = config.model_type self.model = StableLM(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.args = config def __call__( self, @@ -203,9 +203,17 @@ class Model(nn.Module): mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(x.dtype) - y, cache = self.model(x, mask, cache) - return self.lm_head(y), cache + y = self.model(x, mask, cache) + return self.lm_head(y) @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index e96db5b8..ca06bdb1 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -55,11 +55,9 @@ class Attention(nn.Module): 
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -69,7 +67,7 @@ class Attention(nn.Module): ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class MLP(nn.Module): @@ -102,11 +100,11 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class Starcoder2Model(nn.Module): @@ -137,10 +135,10 @@ class Starcoder2Model(nn.Module): if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) - return self.norm(h), cache + return self.norm(h) class Model(nn.Module): @@ -157,13 +155,21 @@ class Model(nn.Module): inputs: mx.array, cache=None, ): - out, cache = self.model(inputs, cache) + out = self.model(inputs, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: out = self.lm_head(out) - return out, cache + return out @property def layers(self): return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index bfc7bde1..7e251a09 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -314,7 +314,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}): tokenizer_file = model_path / "tokenizer.json" if tokenizer_file.exists(): - tokenizer_content = json.load(tokenizer_file.open()) + with open(tokenizer_file, "r") as fid: + tokenizer_content = json.load(fid) if "decoder" in tokenizer_content: if _is_spm_decoder(tokenizer_content["decoder"]): detokenizer_class = SPMStreamingDetokenizer diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5a190b7d..b3667609 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -18,6 +18,7 @@ from mlx.utils import tree_flatten from transformers import AutoTokenizer, PreTrainedTokenizer # Local imports +from .models.base import KVCache from .sample_utils import top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import apply_lora_layers @@ -160,7 +161,12 @@ def generate_step( ) y = prompt - cache = None + kv_heads = ( + [model.n_kv_heads] * len(model.layers) + if isinstance(model.n_kv_heads, int) + else model.n_kv_heads + ) + cache = [KVCache(model.head_dim, n) for n in kv_heads] repetition_context = prompt.tolist() @@ -168,8 +174,8 @@ def generate_step( repetition_context = repetition_context[-repetition_context_size:] def _step(y): - nonlocal cache, repetition_context - logits, cache = 
model(y[None], cache=cache) + nonlocal repetition_context + logits = model(y[None], cache=cache) logits = logits[:, -1, :] if repetition_penalty: @@ -445,9 +451,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): card.text = dedent( f""" # {upload_repo} - + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. - + ## Use with mlx ```bash diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 3e8b1fe1..25ba8398 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.12.0" +__version__ = "0.13.0" diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 5435924c..225e9d27 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -4,6 +4,7 @@ import unittest import mlx.core as mx from mlx.utils import tree_map +from mlx_lm.models.base import KVCache class TestModels(unittest.TestCase): @@ -17,13 +18,18 @@ class TestModels(unittest.TestCase): model.update(tree_map(lambda p: p.astype(t), model.parameters())) inputs = mx.array([[0, 1]]) - outputs, cache = model(inputs) + outputs = model(inputs) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - outputs, cache = model( - mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache + kv_heads = ( + [model.n_kv_heads] * len(model.layers) + if isinstance(model.n_kv_heads, int) + else model.n_kv_heads ) + cache = [KVCache(model.head_dim, n) for n in kv_heads] + + outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache) self.assertEqual(outputs.shape, (1, 1, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -53,6 +59,15 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_phixtral(self): + from mlx_lm.models import phixtral + + args = phixtral.ModelArgs( + "phixtral", num_vocab=1000, num_layers=4, model_dim=1024 + ) + model = phixtral.Model(args) + self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers) + def test_phi3(self): from mlx_lm.models import phi3 @@ -264,6 +279,100 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_dbrx(self): + from mlx_lm.models import dbrx + + args = dbrx.ModelArgs( + model_type="dbrx", + d_model=1024, + ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2}, + attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000}, + n_layers=4, + n_heads=4, + vocab_size=10_000, + ) + model = dbrx.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers) + + def test_minicpm(self): + from mlx_lm.models import minicpm + + args = minicpm.ModelArgs( + model_type="minicpm", + hidden_size=1024, + dim_model_base=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-4, + vocab_size=10000, + num_key_value_heads=2, + scale_depth=1.0, + scale_emb=1.0, + ) + model = minicpm.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_openelm(self): + from mlx_lm.models import openelm + + args = openelm.ModelArgs( + model_type="openelm", + ffn_dim_divisor=256, + ffn_multipliers=[ + 0.5, + 0.73, + 0.97, + 1.2, + 1.43, + 1.67, + 1.9, + 2.13, + 2.37, + 2.6, + 2.83, + 3.07, + 3.3, + 3.53, + 3.77, + 4.0, + ], + 
head_dim=64, + model_dim=1280, + normalize_qk_projections=True, + num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5], + num_query_heads=[ + 12, + 12, + 12, + 12, + 12, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 20, + 20, + 20, + 20, + ], + num_transformer_layers=16, + vocab_size=32000, + ) + + model = openelm.Model(args) + self.model_test_runner( + model, + args.model_type, + args.vocab_size, + len(args.ffn_multipliers), + ) + if __name__ == "__main__": unittest.main()
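---

For context, a minimal usage sketch (not part of the patch) showing how callers drive the new in-place `KVCache` after this change, mirroring the updated `generate_step` in `utils.py` and the test in `test_models.py` above. The model repo id, prompt, and greedy-decoding loop are illustrative assumptions, not anything prescribed by the diff:

```python
# Sketch only: assumes mlx-lm at or after this patch (models return logits only,
# and expose the new `head_dim` / `n_kv_heads` properties added above).
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.base import KVCache

# Illustrative model; any mlx-lm model converted for this version should work.
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

prompt = mx.array(tokenizer.encode("Write a haiku about caches."))[None]

# One cache per layer. `model.n_kv_heads` is an int for most models, or a
# per-layer list for models such as OpenELM (see the test args above).
kv_heads = (
    [model.n_kv_heads] * len(model.layers)
    if isinstance(model.n_kv_heads, int)
    else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]

# The forward pass now returns only logits; the caches are updated in place,
# growing their key/value buffers in steps of 256 positions.
logits = model(prompt, cache=cache)
y = mx.argmax(logits[:, -1, :], axis=-1)

tokens = [y.item()]
for _ in range(32):  # greedy decoding, no EOS handling in this sketch
    logits = model(y[None], cache=cache)
    y = mx.argmax(logits[:, -1, :], axis=-1)
    tokens.append(y.item())

print(tokenizer.decode(tokens))
```

In practice `mlx_lm.generate` / `generate_step` build these per-layer caches internally (as shown in the `utils.py` hunk), so constructing `KVCache` objects by hand is only needed when calling the models directly.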