From 46da74fea2404eb987e8070d84f5745a124dbbc8 Mon Sep 17 00:00:00 2001
From: otriscon <165947759+otriscon@users.noreply.github.com>
Date: Thu, 25 Jul 2024 19:45:22 -0400
Subject: [PATCH] Unify attention mask in LLMs (#911)

* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
mask = None
if h.shape[1] > 1:
    mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
    mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one
token, but it assumes the multi-token input always occurs at the beginning of
inference. If, for example, we are evaluating multiple tokens because of
speculative decoding or prompt cache reuse, this mask will not have the correct
shape and will cause an exception to be raised in the attention computation.

Some of the models implement the mask creation correctly, with code like this:

```
mask = None
if h.shape[1] > 1:
    mask = create_additive_causal_mask(
        h.shape[1], cache[0].offset if cache is not None else 0
    )
    mask = mask.astype(h.dtype)
```

This commit unifies attention mask creation across all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that the input
batch is of size 1. Input batching (evaluating multiple alternative inputs at
the same time) can be a valuable tool for speculative sampling and other
techniques. This change removes the hard-coded batch size from the code that
resizes the key-value cache.

* Simplify causal mask creation

Use the same code path regardless of whether there's an offset or not.
Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
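
For illustration only (not part of the patch): a minimal sketch of how a model
forward pass is expected to call the new helper once this change is applied.
The `forward` function and the `embed_tokens`, `layers`, and `norm` attributes
are placeholders for whatever a concrete model defines, and the import path
assumes the installed `mlx_lm` package layout.

```
import mlx.core as mx

from mlx_lm.models.base import create_attention_mask


def forward(model, inputs: mx.array, cache=None):
    # inputs: (batch, seq_len) token ids
    h = model.embed_tokens(inputs)

    # With a warm cache (prompt reuse, speculative decoding), the mask is
    # offset by cache[0].offset, so its shape is (seq_len, offset + seq_len)
    # and lines up with the attention scores instead of raising.
    mask = create_attention_mask(h, cache)

    if cache is None:
        cache = [None] * len(model.layers)
    for layer, c in zip(model.layers, cache):
        h = layer(h, mask, c)
    return model.norm(h)
```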
---
 llms/mlx_lm/models/base.py        | 35 +++++++++++++++++++++++--------
 llms/mlx_lm/models/cohere.py      |  7 ++-----
 llms/mlx_lm/models/dbrx.py        |  8 ++-----
 llms/mlx_lm/models/deepseek_v2.py |  8 ++-----
 llms/mlx_lm/models/gemma.py       |  7 ++-----
 llms/mlx_lm/models/gemma2.py      |  7 ++-----
 llms/mlx_lm/models/gpt2.py        |  7 ++-----
 llms/mlx_lm/models/gpt_bigcode.py |  7 ++-----
 llms/mlx_lm/models/gpt_neox.py    |  9 ++------
 llms/mlx_lm/models/internlm2.py   |  7 ++-----
 llms/mlx_lm/models/llama.py       |  9 ++------
 llms/mlx_lm/models/minicpm.py     |  7 ++-----
 llms/mlx_lm/models/mixtral.py     |  8 ++-----
 llms/mlx_lm/models/olmo.py        |  7 ++-----
 llms/mlx_lm/models/openelm.py     |  7 ++-----
 llms/mlx_lm/models/phi.py         | 10 ++++-----
 llms/mlx_lm/models/phi3.py        |  7 ++-----
 llms/mlx_lm/models/phi3small.py   |  7 ++-----
 llms/mlx_lm/models/phixtral.py    |  6 ++----
 llms/mlx_lm/models/plamo.py       |  7 ++-----
 llms/mlx_lm/models/qwen.py        |  8 ++-----
 llms/mlx_lm/models/qwen2.py       |  7 ++-----
 llms/mlx_lm/models/qwen2_moe.py   |  7 ++-----
 llms/mlx_lm/models/stablelm.py    |  8 ++-----
 llms/mlx_lm/models/starcoder2.py  |  7 ++-----
 25 files changed, 76 insertions(+), 138 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 8c3ecc78..3fe276d2 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -1,14 +1,9 @@
 import inspect
 from dataclasses import dataclass
+from typing import List, Optional
 
 import mlx.core as mx
-
-
-def create_additive_causal_mask(N: int, offset: int = 0):
-    rinds = mx.arange(offset + N)
-    linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
-    return mask * -1e9
+import mlx.nn as nn
 
 
 class KVCache:
@@ -29,9 +24,10 @@ class KVCache:
     def update_and_fetch(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+            B = keys.shape[0]
             n_steps = (self.step + keys.shape[2] - 1) // self.step
-            k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
-            v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+            k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+            v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
             new_k = mx.zeros(k_shape, keys.dtype)
             new_v = mx.zeros(v_shape, values.dtype)
             if self.keys is not None:
@@ -60,3 +56,24 @@ class BaseModelArgs:
                 if k in inspect.signature(cls).parameters
             }
         )
+
+
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
+def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+    T = h.shape[1]
+    if T > 1:
+        # Input consists of multiple tokens, create a causal mask so that prior
+        # tokens do not give attention to later tokens. If a cache is in place
+        # (because e.g. prompt reuse), offset the mask accordingly.
+        offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+        mask = create_additive_causal_mask(T, offset)
+        mask = mask.astype(h.dtype)
+    else:
+        mask = None
+    return mask
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 621a85a2..7dc2b9bf 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -157,10 +157,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index dc310ca4..7a2a7a7d 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -199,11 +199,7 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 308b94ba..bd743e53 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -408,11 +408,7 @@ class DeepseekV2Model(nn.Module):
         cache: Optional[KVCache] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index e48f1909..323ebaa6 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -141,10 +141,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index 1ab403da..d4bd8a5d 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -165,10 +165,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index ece7b6ec..81f71cac 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -136,10 +136,7 @@ class GPT2Model(nn.Module):
             position_ids = mx.array(np.arange(L))
             hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 20af3d0b..a5336203 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -147,10 +147,7 @@ class GPTBigCodeModel(nn.Module):
             position_ids = mx.array(np.arange(L))
             hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 9549f322..1d2f74b7 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 # Based on the transformers implementation at:
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -150,12 +150,7 @@ class GPTNeoXModel(nn.Module):
 
         hidden_states = self.embed_in(inputs)
 
-        mask = None
-        if hidden_states.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                hidden_states.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index fc1cfc03..2ee2af2d 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -195,10 +195,7 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index a697e2b7..2f323245 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -271,12 +271,7 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                h.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index dbfe4186..a3d01cbb 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -160,10 +160,7 @@ class MiniCPMModel(nn.Module):
    ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 7d1b10ac..c7d8c5c5 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -164,11 +164,7 @@ class MixtralModel(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 120ea9b9..8a28ad74 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 try:
     import hf_olmo
@@ -126,10 +126,7 @@ class Transformer(nn.Module):
    ):
         h = self.wte(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 3fbdc58c..3f0d2605 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -180,10 +180,7 @@ class OpenELMModel(nn.Module):
    ):
         h = self.token_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 8feaa23a..520ac1ad 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -138,14 +138,12 @@ class PhiModel(nn.Module):
 
     def __call__(self, x, cache):
         x = self.embed_tokens(x)
+
+        mask = create_attention_mask(x, cache)
+
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
         for layer, c in zip(self.layers, cache):
             x = layer(x, mask, c)
         return self.final_layernorm(x)
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index dd2d6d82..2536aacb 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .su_rope import SuScaledRotaryEmbedding
 
 
@@ -172,10 +172,7 @@ class Phi3Model(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index e0f2d856..de075652 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -6,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -263,10 +263,7 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 40a3bc4b..f0aef0c9 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -6,6 +6,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
+from .base import create_attention_mask
 from .switch_layers import SwitchMLP
 
 
@@ -167,10 +168,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 2d0ddaed..47a9ea4f 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -171,10 +171,7 @@ class PlamoModel(nn.Module):
     ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(self.embed_tokens.weight.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 44b6dfd3..67816599 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -4,7 +4,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -122,11 +122,7 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
 
-        mask = None
-        T = x.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index fab09003..cb8268aa 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -151,10 +151,7 @@ class Qwen2Model(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 57f154a0..121ab813 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -189,10 +189,7 @@ class Qwen2MoeModel(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 30e3a332..9b4d043c 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -198,11 +198,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
+        mask = create_attention_mask(x, cache)
         y = self.model(x, mask, cache)
         return self.lm_head(y)
 
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 7b058d8f..a6eb5377 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -127,10 +127,7 @@ class Starcoder2Model(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
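
A quick sanity check one could run after applying the patch (illustrative only,
not part of the diff; the `offset` and `T` values are arbitrary and the import
path assumes the installed `mlx_lm` package layout):

```
import mlx.core as mx

from mlx_lm.models.base import create_additive_causal_mask

offset, T = 8, 4  # 8 tokens already in the cache, 4 new tokens to score
mask = create_additive_causal_mask(T, offset)

# Each of the T new rows may attend to every cached position plus itself,
# and is blocked (large negative additive value) for later positions.
assert mask.shape == (T, offset + T)
assert mx.all(mask[0, : offset + 1] == 0).item()
assert mx.all(mask[0, offset + 1 :] < 0).item()
```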