diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 8c3ecc78..3fe276d2 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -1,14 +1,9 @@
 import inspect
 from dataclasses import dataclass
+from typing import List, Optional
 
 import mlx.core as mx
-
-
-def create_additive_causal_mask(N: int, offset: int = 0):
-    rinds = mx.arange(offset + N)
-    linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
-    return mask * -1e9
+import mlx.nn as nn
 
 
 class KVCache:
@@ -29,9 +24,10 @@ class KVCache:
     def update_and_fetch(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+            B = keys.shape[0]
             n_steps = (self.step + keys.shape[2] - 1) // self.step
-            k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
-            v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+            k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+            v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
             new_k = mx.zeros(k_shape, keys.dtype)
             new_v = mx.zeros(v_shape, values.dtype)
             if self.keys is not None:
@@ -60,3 +56,24 @@ class BaseModelArgs:
                 if k in inspect.signature(cls).parameters
             }
         )
+
+
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
+def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+    T = h.shape[1]
+    if T > 1:
+        # The input consists of multiple tokens; create a causal mask so that
+        # earlier tokens do not attend to later tokens. If a cache is in place
+        # (e.g. because of prompt reuse), offset the mask accordingly.
+        offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+        mask = create_additive_causal_mask(T, offset)
+        mask = mask.astype(h.dtype)
+    else:
+        mask = None
+    return mask
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 621a85a2..7dc2b9bf 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -157,10 +157,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index dc310ca4..7a2a7a7d 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -199,11 +199,7 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 308b94ba..bd743e53 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -408,11 +408,7 @@ class DeepseekV2Model(nn.Module):
         cache: Optional[KVCache] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index e48f1909..323ebaa6 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -141,10 +141,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index 1ab403da..d4bd8a5d 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -165,10 +165,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index ece7b6ec..81f71cac 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -136,10 +136,7 @@ class GPT2Model(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 20af3d0b..a5336203 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -147,10 +147,7 @@ class GPTBigCodeModel(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 9549f322..1d2f74b7 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 # Based on the transformers implementation at:
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -150,12 +150,7 @@ class GPTNeoXModel(nn.Module):
 
         hidden_states = self.embed_in(inputs)
 
-        mask = None
-        if hidden_states.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                hidden_states.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index fc1cfc03..2ee2af2d 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -195,10 +195,7 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index a697e2b7..2f323245 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -271,12 +271,7 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                h.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index dbfe4186..a3d01cbb 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -160,10 +160,7 @@ class MiniCPMModel(nn.Module):
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 7d1b10ac..c7d8c5c5 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -164,11 +164,7 @@ class MixtralModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 120ea9b9..8a28ad74 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 try:
     import hf_olmo
@@ -126,10 +126,7 @@ class Transformer(nn.Module):
     ):
         h = self.wte(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 3fbdc58c..3f0d2605 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -180,10 +180,7 @@ class OpenELMModel(nn.Module):
     ):
         h = self.token_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 8feaa23a..520ac1ad 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -138,14 +138,12 @@ class PhiModel(nn.Module):
 
     def __call__(self, x, cache):
         x = self.embed_tokens(x)
+
+        mask = create_attention_mask(x, cache)
+
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
         for layer, c in zip(self.layers, cache):
             x = layer(x, mask, c)
         return self.final_layernorm(x)
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index dd2d6d82..2536aacb 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .su_rope import SuScaledRotaryEmbedding
 
 
@@ -172,10 +172,7 @@ class Phi3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index e0f2d856..de075652 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -6,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -263,10 +263,7 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 40a3bc4b..f0aef0c9 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -6,6 +6,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
+from .base import create_attention_mask
 from .switch_layers import SwitchMLP
 
 
@@ -167,10 +168,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 2d0ddaed..47a9ea4f 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -171,10 +171,7 @@ class PlamoModel(nn.Module):
     ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(self.embed_tokens.weight.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 44b6dfd3..67816599 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -4,7 +4,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -122,11 +122,7 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
 
-        mask = None
-        T = x.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index fab09003..cb8268aa 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -151,10 +151,7 @@ class Qwen2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 57f154a0..121ab813 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -189,10 +189,7 @@ class Qwen2MoeModel(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 30e3a332..9b4d043c 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -198,11 +198,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
+        mask = create_attention_mask(x, cache)
         y = self.model(x, mask, cache)
         return self.lm_head(y)
 
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 7b058d8f..a6eb5377 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -127,10 +127,7 @@ class Starcoder2Model(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
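
Note (not part of the patch): a minimal usage sketch of the unified helper, assuming the `KVCache` constructor still takes `(head_dim, n_kv_heads)` as in `base.py` above. `ToyModel`, its sizes, and the token ids are hypothetical and only illustrate how `create_attention_mask` interacts with a cache; they do not come from the patch.

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.base import KVCache, create_attention_mask


class ToyModel(nn.Module):
    def __init__(self, vocab_size=32, dims=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dims)

    def __call__(self, inputs, cache=None):
        h = self.embed_tokens(inputs)
        # One call replaces the per-model mask boilerplate removed above:
        # None for a single-token step, an offset additive causal mask otherwise.
        mask = create_attention_mask(h, cache)
        return h, mask


model = ToyModel()
cache = [KVCache(head_dim=8, n_kv_heads=2)]  # assumed constructor signature

# Prompt pass over 4 tokens: the cache offset is 0, so the mask is (4, 4).
_, prompt_mask = model(mx.array([[1, 2, 3, 4]]), cache)

# Advance the cache by 4 positions, as an attention layer would.
k = mx.zeros((1, 2, 4, 8))
v = mx.zeros((1, 2, 4, 8))
cache[0].update_and_fetch(k, v)

# A later multi-token call is offset by the cached length: the mask is (2, 6),
# so the new tokens attend to the 4 cached positions plus themselves.
_, next_mask = model(mx.array([[5, 6]]), cache)

# A single-token decode step gets no mask at all.
_, step_mask = model(mx.array([[7]]), cache)
assert step_mask is None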