diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index b5fee238..ad7a4a65 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -37,10 +37,8 @@ def create_causal_mask(
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
     if lengths is not None:
-        mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
-        lengths = lengths[:, None, None]
-        mask = mask | (rinds[None] >= lengths)
-        mask = mask[:, None]
+        lengths = lengths[:, None, None, None]
+        mask = mask | (rinds >= lengths)
     return mask * -1e9


diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 89d64208..b2d16dd7 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -160,7 +160,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index ee74fce1..ec0e9276 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -156,7 +156,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index 73a96810..886b5630 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -202,7 +202,8 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py
index c97afdcb..ffc30c36 100644
--- a/llms/mlx_lm/models/deepseek.py
+++ b/llms/mlx_lm/models/deepseek.py
@@ -214,7 +214,8 @@ class DeepseekModel(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index aad8bdff..9027da7e 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -373,7 +373,9 @@ class DeepseekV2Model(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py
index ce4a649e..ee3ed1e8 100644
--- a/llms/mlx_lm/models/exaone.py
+++ b/llms/mlx_lm/models/exaone.py
@@ -127,7 +127,8 @@ class ExaoneModel(nn.Module):
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index bf5a68db..0860ddeb 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -144,7 +144,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index df1fdb3a..321a58ff 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -166,7 +166,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index 706e2df6..5b277734 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -139,7 +139,8 @@ class GPT2Model(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index e7760ba1..8415c59e 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -150,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 327ff847..5e124a67 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -153,7 +153,8 @@ class GPTNeoXModel(nn.Module):

         hidden_states = self.embed_in(inputs)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py
index 0a34957a..f9dc5652 100644
--- a/llms/mlx_lm/models/hunyuan.py
+++ b/llms/mlx_lm/models/hunyuan.py
@@ -244,7 +244,8 @@ class HunYuanModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index c802a8f9..28a095e1 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -198,7 +198,8 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 625ae541..7b452ea4 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -160,7 +160,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index 79fa4f16..edddd583 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -163,7 +163,8 @@ class MiniCPMModel(nn.Module):
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index ec0253a3..0afd1235 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -167,7 +167,8 @@ class MixtralModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py
index 2d69b0eb..eabfac8c 100644
--- a/llms/mlx_lm/models/nemotron.py
+++ b/llms/mlx_lm/models/nemotron.py
@@ -181,7 +181,8 @@ class NemotronModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index cc382876..4273b0ec 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -129,7 +129,8 @@ class Transformer(nn.Module):
     ):
         h = self.wte(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py
index ee19cf0e..510ff882 100644
--- a/llms/mlx_lm/models/olmo2.py
+++ b/llms/mlx_lm/models/olmo2.py
@@ -167,7 +167,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index d0b5a48b..504fe95c 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -183,7 +183,8 @@ class OpenELMModel(nn.Module):
     ):
         h = self.token_embeddings(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 06b49142..e9724691 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -146,7 +146,8 @@ class PhiModel(nn.Module):
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)

-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index 7d9edfba..d1c21e25 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -173,7 +173,8 @@ class Phi3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index d43ca73e..cd566eec 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -265,7 +265,8 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py
index 14eade9c..bddcb128 100644
--- a/llms/mlx_lm/models/phimoe.py
+++ b/llms/mlx_lm/models/phimoe.py
@@ -160,7 +160,8 @@ class PhiMoEModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 6c0e5750..5477c2c0 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -175,7 +175,9 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         y = self.transformer(x, mask, cache)
         return self.lm_head(y)
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 080a916f..9107daad 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -178,7 +178,8 @@ class PlamoModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 03218dde..ec8a0199 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)

-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         if cache is None:
             cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index c956cc47..381767c4 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -154,7 +154,8 @@ class Qwen2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index dbef3d00..c6aba622 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -192,7 +192,8 @@ class Qwen2MoeModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 7bd76ada..ad07d925 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -403,7 +403,8 @@ class Griffin(nn.Module):
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]

-        mask = mask or create_attention_mask(x, mask_cache)
+        if mask is None:
+            mask = create_attention_mask(x, mask_cache)

         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 67deef5b..0bbc2ca4 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -199,7 +199,10 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)
+
         y = self.model(x, mask, cache)
         return self.lm_head(y)

diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 2a2616d2..71c397f6 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -130,7 +130,8 @@ class Starcoder2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 00c01436..7b4376bb 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -183,6 +183,12 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)

+            if model_type != "mamba":
+                mask = create_causal_mask(inputs.shape[1], 0).astype(t)
+                outputs = model(inputs, mask=mask)
+                self.assertEqual(outputs.shape, (1, 2, vocab_size))
+                self.assertEqual(outputs.dtype, t)
+
             outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
             self.assertEqual(outputs.shape, (1, 1, vocab_size))
             self.assertEqual(outputs.dtype, t)
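
For reference, a minimal usage sketch (the example values are illustrative and not part of the patch; only `create_causal_mask`, the `lengths` parameter, and the per-model `mask=` keyword come from the diff above). With this change, `create_causal_mask` broadcasts `lengths` via `lengths[:, None, None, None]` so the batched additive mask comes out with shape (B, 1, N, N) without materializing a copy per batch row, and every model only builds its own mask when `mask is None`, so a caller can pass a precomputed one, as the new test does.

    import mlx.core as mx
    from mlx_lm.models.base import create_causal_mask

    # Plain causal mask for a 4-token prompt: 0 where attention is allowed,
    # -1e9 on future positions. Shape (4, 4).
    mask = create_causal_mask(4, 0)

    # With per-sequence lengths, the boolean mask broadcasts against
    # lengths[:, None, None, None], giving shape (2, 1, 4, 4) here:
    # one mask per sequence, ready to add to (B, n_heads, N, N) scores.
    padded = create_causal_mask(4, 0, lengths=mx.array([2, 4]))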