From cd9dcf038348dd0fbe205accced4a10a0ace95ae Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Wed, 18 Dec 2024 13:54:14 -0800
Subject: [PATCH] add mask to mlx_lm model interface

---
 llms/mlx_lm/models/cache.py           | 10 ++--------
 llms/mlx_lm/models/cohere.py          |  6 ++++--
 llms/mlx_lm/models/cohere2.py         | 13 +++++--------
 llms/mlx_lm/models/dbrx.py            |  6 ++++--
 llms/mlx_lm/models/deepseek.py        |  6 ++++--
 llms/mlx_lm/models/deepseek_v2.py     |  6 ++++--
 llms/mlx_lm/models/exaone.py          |  6 ++++--
 llms/mlx_lm/models/gemma.py           |  6 ++++--
 llms/mlx_lm/models/gemma2.py          |  6 ++++--
 llms/mlx_lm/models/gpt2.py            |  6 ++++--
 llms/mlx_lm/models/gpt_bigcode.py     |  6 ++++--
 llms/mlx_lm/models/gpt_neox.py        |  6 ++++--
 llms/mlx_lm/models/hunyuan.py         |  6 ++++--
 llms/mlx_lm/models/internlm2.py       |  6 ++++--
 llms/mlx_lm/models/llama.py           |  6 ++++--
 llms/mlx_lm/models/minicpm.py         |  6 ++++--
 llms/mlx_lm/models/mixtral.py         |  6 ++++--
 llms/mlx_lm/models/nemotron.py        |  6 ++++--
 llms/mlx_lm/models/olmo.py            |  9 ++++++---
 llms/mlx_lm/models/olmo2.py           |  6 ++++--
 llms/mlx_lm/models/openelm.py         |  6 ++++--
 llms/mlx_lm/models/phi.py             |  7 ++++---
 llms/mlx_lm/models/phi3.py            |  6 ++++--
 llms/mlx_lm/models/phi3small.py       |  6 ++++--
 llms/mlx_lm/models/phimoe.py          |  6 ++++--
 llms/mlx_lm/models/phixtral.py        |  2 +-
 llms/mlx_lm/models/plamo.py           |  6 ++++--
 llms/mlx_lm/models/qwen.py            |  2 +-
 llms/mlx_lm/models/qwen2.py           |  6 ++++--
 llms/mlx_lm/models/qwen2_moe.py       |  6 ++++--
 llms/mlx_lm/models/recurrent_gemma.py |  7 ++++---
 llms/mlx_lm/models/stablelm.py        |  2 +-
 llms/mlx_lm/models/starcoder2.py      |  6 ++++--
 llms/tests/test_models.py             |  2 +-
 34 files changed, 125 insertions(+), 79 deletions(-)

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 1e311381..81b16af3 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -10,7 +10,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
-    lengths: Optional[mx.array] = None,
 ) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -23,22 +22,17 @@
         max_kv_size (Optional[int]): If provided and the model does not have a
             ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
             size of ``max_kv_size``
-        lengths (Optional[array]): If provided these sequence lengths will be
-            used mask the KV cache. Useful for batch inputs.
""" if hasattr(model, "make_cache"): return model.make_cache() num_layers = len(model.layers) if max_kv_size is not None: - cache = [ + return [ RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) ] else: - cache = [KVCache() for _ in range(num_layers)] - - cache[0].lengths = lengths - return cache + return [KVCache() for _ in range(num_layers)] def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 7e002b0c..89d64208 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -155,11 +155,12 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -180,9 +181,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index fcb4061b..ee74fce1 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import KVCache, RotatingKVCache @@ -151,16 +151,12 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - T = h.shape[1] - if T > 1: - offset = cache[0].offset if cache else 0 - mask = create_causal_mask(T, offset).astype(h.dtype) - else: - mask = None + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +177,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 7be274cc..73a96810 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -197,11 +197,12 @@ class DBRX(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -223,9 +224,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index b7b24dba..c97afdcb 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -211,9 +211,10 @@ class DeepseekModel(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, 
cache) if cache is None: cache = [None] * len(self.layers) @@ -236,8 +237,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 444813b9..aad8bdff 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -370,9 +370,10 @@ class DeepseekV2Model(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -395,8 +396,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py index eaed5dd8..ce4a649e 100644 --- a/llms/mlx_lm/models/exaone.py +++ b/llms/mlx_lm/models/exaone.py @@ -123,10 +123,11 @@ class ExaoneModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.h) @@ -149,9 +150,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 3f384c3f..bf5a68db 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -138,12 +138,13 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -164,9 +165,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) return out diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 64951ae4..df1fdb3a 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -160,12 +160,13 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -187,9 +188,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = mx.tanh(out / self.final_logit_softcapping) out = out * self.final_logit_softcapping diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 
52076a34..706e2df6 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -126,6 +126,7 @@ class GPT2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape @@ -138,7 +139,7 @@ class GPT2Model(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + mask = mask or create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -159,9 +160,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.wte.as_linear(out) return out diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 23e86e20..e7760ba1 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): B, L = inputs.shape @@ -149,7 +150,7 @@ class GPTBigCodeModel(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + mask = mask or create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -172,9 +173,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index ccb0b28b..327ff847 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -146,13 +146,14 @@ class GPTNeoXModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape hidden_states = self.embed_in(inputs) - mask = create_attention_mask(hidden_states, cache) + mask = mask or create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -176,9 +177,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return out def sanitize(self, weights): diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py index b098c20d..0a34957a 100644 --- a/llms/mlx_lm/models/hunyuan.py +++ b/llms/mlx_lm/models/hunyuan.py @@ -239,11 +239,12 @@ class HunYuanModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -266,9 +267,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.model.embed_tokens.as_linear(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index f5ce057e..c802a8f9 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -193,11 +193,12 @@ class InternLM2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = 
self.tok_embeddings(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -220,9 +221,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.tok_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 290cb83e..625ae541 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -155,11 +155,12 @@ class LlamaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -182,9 +183,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 907beb2a..79fa4f16 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -158,11 +158,12 @@ class MiniCPMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) * self.args.scale_emb - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -186,9 +187,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if not self.args.tie_word_embeddings: out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index dd94d1f4..ec0253a3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -162,11 +162,12 @@ class MixtralModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -188,9 +189,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index f73c0277..2d69b0eb 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -176,11 +176,12 @@ class NemotronModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -203,9 +204,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git 
a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 3627df06..cc382876 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -124,11 +124,12 @@ class Transformer(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -152,9 +153,10 @@ class OlmoModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.transformer(inputs, cache) + return self.transformer(inputs, mask, cache) class Model(nn.Module): @@ -167,9 +169,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.model(inputs, cache) + return self.model(inputs, mask, cache) @property def layers(self): diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index 64d7e116..ee19cf0e 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -163,10 +163,11 @@ class LlamaModel(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -190,8 +191,9 @@ class Model(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 408802f4..d0b5a48b 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -178,11 +178,12 @@ class OpenELMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.token_embeddings(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -205,9 +206,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.share_input_output_layers: out = self.transformer.token_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 510025ea..06b49142 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -143,10 +143,10 @@ class PhiModel(nn.Module): config.hidden_size, eps=config.layer_norm_eps ) - def __call__(self, x, cache): + def __call__(self, x, mask, cache): x = self.embed_tokens(x) - mask = create_attention_mask(x, cache) + mask = mask or create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.layers) @@ -167,9 +167,10 @@ class Model(nn.Module): def __call__( self, x: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: - y = self.model(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) @property diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index ee6efc49..7d9edfba 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -168,11 +168,12 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if 
cache is None: cache = [None] * len(self.layers) @@ -194,9 +195,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 53e1a638..d43ca73e 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -258,13 +258,14 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) if self.mup_embedding_multiplier: h = self.mup_embedding_multiplier * h - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -290,9 +291,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) if self.mup_width_multiplier: out = out / self.mup_width_multiplier diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index f42a6dd0..14eade9c 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -155,11 +155,12 @@ class PhiMoEModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +182,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 42d647b0..6c0e5750 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -175,7 +175,7 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + mask = mask or create_attention_mask(x, cache) y = self.transformer(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index c8e5bf50..080a916f 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -174,10 +174,11 @@ class PlamoModel(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None for _ in range(len(self.layers.layers))] @@ -202,8 +203,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 8145a890..03218dde 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -123,7 +123,7 @@ class QwenModel(nn.Module): def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) - mask = create_attention_mask(x, cache) + mask = mask or create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.h) diff --git 
a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index fac59d78..c956cc47 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -149,11 +149,12 @@ class Qwen2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -176,9 +177,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index 167fc5dd..dbef3d00 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -187,11 +187,12 @@ class Qwen2MoeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -213,9 +214,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 49e4bb8f..7bd76ada 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -389,6 +389,7 @@ class Griffin(nn.Module): def __call__( self, tokens, + mask: mx.array = None, cache=None, ): x = self.embed_tokens(tokens) @@ -402,7 +403,7 @@ class Griffin(nn.Module): if block.temporal_block_type != "recurrent": mask_cache = [cache[i]] - mask = create_attention_mask(x, mask_cache) + mask = mask or create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -418,12 +419,12 @@ class Model(nn.Module): self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - def __call__(self, tokens: mx.array, cache=None) -> mx.array: + def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array: """ Args: tokens: Sequence of input tokens. 
""" - logits = self.model(tokens, cache=cache) + logits = self.model(tokens, mask=mask, cache=cache) if "lm_head" in self: logits = self.lm_head(logits) else: diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 482bb324..67deef5b 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -199,7 +199,7 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + mask = mask or create_attention_mask(x, cache) y = self.model(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index d7e626f2..2a2616d2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -125,11 +125,12 @@ class Starcoder2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = mask or create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -152,9 +153,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 6ae7d803..61dd8c58 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -182,7 +182,7 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.dtype, t) cache = make_prompt_cache(model) - outputs = model(inputs, cache) + outputs = model(inputs, cache=cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t)