From d4ef909d4ab44d9f8cf89f5baa8a433d76d7d6b1 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Wed, 18 Dec 2024 19:43:52 -0800
Subject: [PATCH] Length masking for batch inputs (#1173)

* length masking
* add mask to mlx_lm model interface
* remove lengths
* fix test:
* comment + fix
---
 llms/mlx_lm/models/base.py            | 10 +++++++++-
 llms/mlx_lm/models/cohere.py          |  7 +++++--
 llms/mlx_lm/models/cohere2.py         | 14 ++++++--------
 llms/mlx_lm/models/dbrx.py            |  7 +++++--
 llms/mlx_lm/models/deepseek.py        |  7 +++++--
 llms/mlx_lm/models/deepseek_v2.py     |  8 ++++++--
 llms/mlx_lm/models/exaone.py          |  7 +++++--
 llms/mlx_lm/models/gemma.py           |  7 +++++--
 llms/mlx_lm/models/gemma2.py          |  7 +++++--
 llms/mlx_lm/models/gpt2.py            |  7 +++++--
 llms/mlx_lm/models/gpt_bigcode.py     |  7 +++++--
 llms/mlx_lm/models/gpt_neox.py        |  7 +++++--
 llms/mlx_lm/models/hunyuan.py         |  7 +++++--
 llms/mlx_lm/models/internlm2.py       |  7 +++++--
 llms/mlx_lm/models/llama.py           |  7 +++++--
 llms/mlx_lm/models/minicpm.py         |  7 +++++--
 llms/mlx_lm/models/mixtral.py         |  7 +++++--
 llms/mlx_lm/models/nemotron.py        |  7 +++++--
 llms/mlx_lm/models/olmo.py            | 10 +++++++---
 llms/mlx_lm/models/olmo2.py           |  7 +++++--
 llms/mlx_lm/models/openelm.py         |  7 +++++--
 llms/mlx_lm/models/phi.py             |  8 +++++---
 llms/mlx_lm/models/phi3.py            |  7 +++++--
 llms/mlx_lm/models/phi3small.py       |  7 +++++--
 llms/mlx_lm/models/phimoe.py          |  7 +++++--
 llms/mlx_lm/models/phixtral.py        |  4 +++-
 llms/mlx_lm/models/plamo.py           |  7 +++++--
 llms/mlx_lm/models/qwen.py            |  3 ++-
 llms/mlx_lm/models/qwen2.py           |  7 +++++--
 llms/mlx_lm/models/qwen2_moe.py       |  7 +++++--
 llms/mlx_lm/models/recurrent_gemma.py |  8 +++++---
 llms/mlx_lm/models/stablelm.py        |  5 ++++-
 llms/mlx_lm/models/starcoder2.py      |  7 +++++--
 llms/tests/test_models.py             | 25 ++++++++++++++++++++++++-
 34 files changed, 191 insertions(+), 72 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index f02f49b1..ad7a4a65 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -23,7 +23,12 @@ class BaseModelArgs:
         )
 
 
-def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
+def create_causal_mask(
+    N: int,
+    offset: int = 0,
+    window_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
+):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     mask = linds < rinds
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
+    if lengths is not None:
+        lengths = lengths[:, None, None, None]
+        mask = mask | (rinds >= lengths)
     return mask * -1e9
 
 
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 7e002b0c..b2d16dd7 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -155,11 +155,13 @@ class CohereModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -180,9 +182,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index fcb4061b..ec0e9276 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import KVCache, RotatingKVCache
 
 
@@ -151,16 +151,13 @@ class CohereModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        T = h.shape[1]
-        if T > 1:
-            offset = cache[0].offset if cache else 0
-            mask = create_causal_mask(T, offset).astype(h.dtype)
-        else:
-            mask = None
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -181,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index 7be274cc..886b5630 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -197,11 +197,13 @@ class DBRX(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
@@ -223,9 +225,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         return self.lm_head(out)
 
     @property
diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py
index b7b24dba..ffc30c36 100644
--- a/llms/mlx_lm/models/deepseek.py
+++ b/llms/mlx_lm/models/deepseek.py
@@ -211,9 +211,11 @@ class DeepseekModel(nn.Module):
         self,
         x: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -236,8 +238,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 444813b9..9027da7e 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module):
         self,
         x: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = create_attention_mask(h, cache)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -395,8 +398,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py
index eaed5dd8..ee3ed1e8 100644
--- a/llms/mlx_lm/models/exaone.py
+++ b/llms/mlx_lm/models/exaone.py
@@ -123,10 +123,12 @@ class ExaoneModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
@@ -149,9 +151,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.transformer.wte.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 3f384c3f..0860ddeb 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -138,12 +138,14 @@ class GemmaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -164,9 +166,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         return out
 
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index 64951ae4..321a58ff 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -160,12 +160,14 @@ class GemmaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -187,9 +189,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = mx.tanh(out / self.final_logit_softcapping)
         out = out * self.final_logit_softcapping
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index 52076a34..5b277734 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         _, L = inputs.shape
@@ -138,7 +139,8 @@ class GPT2Model(nn.Module):
             position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
@@ -159,9 +161,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.wte.as_linear(out)
         return out
 
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 23e86e20..8415c59e 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         B, L = inputs.shape
@@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
             position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
@@ -172,9 +174,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.transformer.wte.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index ccb0b28b..5e124a67 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         _, L = inputs.shape
 
         hidden_states = self.embed_in(inputs)
 
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
@@ -176,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return out
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py
index b098c20d..f9dc5652 100644
--- a/llms/mlx_lm/models/hunyuan.py
+++ b/llms/mlx_lm/models/hunyuan.py
@@ -239,11 +239,13 @@ class HunYuanModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -266,9 +268,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.model.embed_tokens.as_linear(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index f5ce057e..28a095e1 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -193,11 +193,13 @@ class InternLM2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.tok_embeddings(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -220,9 +222,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.tok_embeddings.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 290cb83e..7b452ea4 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -155,11 +155,13 @@ class LlamaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -182,9 +184,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index 907beb2a..edddd583 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -186,9 +188,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
 
         if not self.args.tie_word_embeddings:
             out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index dd94d1f4..0afd1235 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -162,11 +162,13 @@ class MixtralModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -188,9 +190,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py
index f73c0277..eabfac8c 100644
--- a/llms/mlx_lm/models/nemotron.py
+++ b/llms/mlx_lm/models/nemotron.py
@@ -176,11 +176,13 @@ class NemotronModel(nn.Module):
    def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -203,9 +205,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 3627df06..4273b0ec 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -124,11 +124,13 @@ class Transformer(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)
@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        return self.transformer(inputs, cache)
+        return self.transformer(inputs, mask, cache)
 
 
 class Model(nn.Module):
@@ -167,9 +170,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        return self.model(inputs, cache)
+        return self.model(inputs, mask, cache)
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py
index 64d7e116..510ff882 100644
--- a/llms/mlx_lm/models/olmo2.py
+++ b/llms/mlx_lm/models/olmo2.py
@@ -163,10 +163,12 @@ class LlamaModel(nn.Module):
         self,
         inputs: mx.array,
         cache=None,
+        mask=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -190,8 +192,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache=None,
+        mask=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 408802f4..504fe95c 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -178,11 +178,13 @@ class OpenELMModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.token_embeddings(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -205,9 +207,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.share_input_output_layers:
             out = self.transformer.token_embeddings.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 510025ea..e9724691 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
             config.hidden_size, eps=config.layer_norm_eps
         )
 
-    def __call__(self, x, cache):
+    def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
 
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -167,9 +168,10 @@ class Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
+        mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        y = self.model(x, cache)
+        y = self.model(x, mask, cache)
         return self.lm_head(y)
 
     @property
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index ee6efc49..d1c21e25 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -168,11 +168,13 @@ class Phi3Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -194,9 +196,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
 
     @property
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index 53e1a638..cd566eec 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -258,13 +258,15 @@ class Phi3Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -290,9 +292,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         if self.mup_width_multiplier:
             out = out / self.mup_width_multiplier
diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py
index f42a6dd0..bddcb128 100644
--- a/llms/mlx_lm/models/phimoe.py
+++ b/llms/mlx_lm/models/phimoe.py
@@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -181,9 +183,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 42d647b0..5477c2c0 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -175,7 +175,9 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)
 
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index c8e5bf50..9107daad 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -174,10 +174,12 @@ class PlamoModel(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
@@ -202,8 +204,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
 
     @property
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 8145a890..ec8a0199 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
 
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
 
         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index fac59d78..381767c4 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -149,11 +149,13 @@ class Qwen2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -176,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 167fc5dd..c6aba622 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -213,9 +215,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
 
     def sanitize(self, weights):
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 49e4bb8f..ad07d925 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -389,6 +389,7 @@ class Griffin(nn.Module):
     def __call__(
         self,
         tokens,
+        mask: mx.array = None,
         cache=None,
     ):
         x = self.embed_tokens(tokens)
@@ -402,7 +403,8 @@ class Griffin(nn.Module):
 
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]
-                mask = create_attention_mask(x, mask_cache)
+                if mask is None:
+                    mask = create_attention_mask(x, mask_cache)
 
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
@@ -418,12 +420,12 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
+    def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        logits = self.model(tokens, cache=cache)
+        logits = self.model(tokens, mask=mask, cache=cache)
         if "lm_head" in self:
             logits = self.lm_head(logits)
         else:
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 482bb324..0bbc2ca4 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -199,7 +199,10 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)
+
         y = self.model(x, mask, cache)
         return self.lm_head(y)
 
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index d7e626f2..71c397f6 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -152,9 +154,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 3097c522..7b4376bb 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map
 from mlx_lm.models import rope_utils
+from mlx_lm.models.base import create_causal_mask
 from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 
 
@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
         self.assertEqual(cache.offset, 22)
         self.assertTrue(mx.allclose(x, k[..., -2:, :]))
 
+    def test_causal_mask_lengths(self):
+        mx.random.seed(8)
+        B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
+        lengths = mx.array([1, 2, 3, 1])
+        q = mx.random.uniform(shape=(B, N_q, T_q, D))
+        k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
+        v = k
+        mask = create_causal_mask(T_q, 0, lengths=lengths)
+
+        out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
+        k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
+        v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
+        out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
+
     def test_rope(self):
         rope = rope_utils.initialize_rope(32, base=100, traditional=False)
         self.assertTrue(isinstance(rope, nn.RoPE))
@@ -162,10 +179,16 @@ class TestModels(unittest.TestCase):
         self.assertEqual(outputs.dtype, t)
 
         cache = make_prompt_cache(model)
-        outputs = model(inputs, cache)
+        outputs = model(inputs, cache=cache)
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)
 
+        if model_type != "mamba":
+            mask = create_causal_mask(inputs.shape[1], 0).astype(t)
+            outputs = model(inputs, mask=mask)
+            self.assertEqual(outputs.shape, (1, 2, vocab_size))
+            self.assertEqual(outputs.dtype, t)
+
         outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
         self.assertEqual(outputs.shape, (1, 1, vocab_size))
         self.assertEqual(outputs.dtype, t)
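With every model now accepting a mask keyword, the new lengths argument to create_causal_mask is what makes ragged batches work end to end: right-pad the prompts into one (B, L) token array, build a single additive mask whose rows also hide each prompt's padding, and pass it to the model. The sketch below is not part of the patch; the checkpoint name is only illustrative, and the float16 cast simply matches the mask to a typical quantized model's compute dtype.

import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.base import create_causal_mask

# Any mlx-lm checkpoint should work; this repo name is just an example.
model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

prompts = ["Hello", "Write a haiku about the ocean."]
token_lists = [tokenizer.encode(p) for p in prompts]
lengths_list = [len(t) for t in token_lists]
lengths = mx.array(lengths_list)
max_len = max(lengths_list)

# Right-pad every prompt so the batch stacks into a single (B, L) array.
padded = mx.array([t + [0] * (max_len - len(t)) for t in token_lists])

# Causal mask that additionally hides each row's padding; shape (B, 1, L, L).
# Cast it to the model's compute dtype (float16 for most quantized checkpoints).
mask = create_causal_mask(max_len, 0, lengths=lengths).astype(mx.float16)

logits = model(padded, mask=mask)  # (B, L, vocab_size)

# The next-token logits for prompt i sit at its last real position, lengths[i] - 1.
next_logits = [logits[i, n - 1] for i, n in enumerate(lengths_list)]

Because the padded key positions are added in at -1e9 before the softmax, tokens past a row's true length cannot influence its real positions, which is exactly the property the new test_causal_mask_lengths test asserts.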