add mask to mlx_lm model interface

This commit is contained in:
Alex Barron 2024-12-18 13:54:14 -08:00
parent c5ce9a31f2
commit cd9dcf0383
34 changed files with 125 additions and 79 deletions

View File

@ -10,7 +10,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
) -> List[Any]:
"""
Construct the model's cache for use when generating.
@ -23,22 +22,17 @@ def make_prompt_cache(
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
lengths (Optional[array]): If provided, these sequence lengths will be
used to mask the KV cache. Useful for batch inputs.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
cache = [
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
cache = [KVCache() for _ in range(num_layers)]
cache[0].lengths = lengths
return cache
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
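
The net effect of this hunk is that the cache no longer carries sequence lengths; padding for batched prompts is now the caller's responsibility via the model's new mask argument. Below is a minimal sketch of that division of labor, not taken from the commit: it assumes a model already loaded with mlx_lm (named `model` here), uses placeholder token ids, and hand-builds an additive mask (0 where attention is allowed, a large negative value on padding), mirroring the convention of create_causal_mask seen later in this diff.

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache

# Hypothetical batch of two right-padded prompts with true lengths 3 and 5.
inputs = mx.zeros((2, 5), dtype=mx.int32)   # placeholder token ids from a tokenizer
lengths = mx.array([3, 5])

cache = make_prompt_cache(model)            # no `lengths` argument anymore

# Additive mask, shape (B, 1, T, T): 0.0 where attending is allowed, -1e9 on padding.
T = inputs.shape[1]
causal = mx.arange(T)[:, None] >= mx.arange(T)[None, :]      # (T, T) lower-triangular
valid = mx.arange(T)[None, :] < lengths[:, None]             # (B, T) non-padded key positions
allowed = mx.logical_and(causal[None], valid[:, None, :])    # (B, T, T)
mask = mx.where(allowed, 0.0, -1e9)[:, None, :, :]           # cast to the model dtype if needed

logits = model(inputs, mask=mask, cache=cache)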

View File

@ -155,11 +155,12 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -180,9 +181,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
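
Both levels of the Cohere wrapper now accept the mask, and the fallback means passing nothing (or None) keeps the previous behaviour. A short sketch of the resulting call pattern, an illustration rather than code from the commit, assuming a loaded half-precision Cohere-style `model` and the create_causal_mask helper imported from mlx_lm.models.base exactly as the next file in this diff does:

import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

tokens = mx.array([[1, 2, 3, 4]])      # placeholder token ids
offset = 0                             # with a warm cache this would be cache[0].offset
explicit = create_causal_mask(tokens.shape[1], offset).astype(mx.float16)

logits_default = model(tokens)                 # mask built internally via create_attention_mask
logits_custom = model(tokens, mask=explicit)   # caller-supplied mask is used instead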

View File

@ -6,7 +6,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@ -151,16 +151,12 @@ class CohereModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -181,9 +177,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

View File

@ -197,11 +197,12 @@ class DBRX(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@ -223,9 +224,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@ -211,9 +211,10 @@ class DeepseekModel(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -236,8 +237,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):

View File

@ -370,9 +370,10 @@ class DeepseekV2Model(nn.Module):
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -395,8 +396,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
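
One detail that varies between the hunks: the Deepseek models here (and PLaMo and one Llama-style model further down) keep the pre-existing `cache` parameter ahead of the new `mask`, while most other families insert `mask` directly after `inputs`. A one-line sketch of the order-proof call, reusing the `inputs`, `mask`, and `cache` objects from the first sketch above:

logits = model(inputs, mask=mask, cache=cache)   # keywords sidestep the positional-order difference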

View File

@ -123,10 +123,11 @@ class ExaoneModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.h)
@ -149,9 +150,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@ -138,12 +138,13 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -164,9 +165,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
return out

View File

@ -160,12 +160,13 @@ class GemmaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -187,9 +188,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping

View File

@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
@ -138,7 +139,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
mask = mask or create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@ -159,9 +160,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out)
return out

View File

@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
B, L = inputs.shape
@ -149,7 +150,7 @@ class GPTBigCodeModel(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
mask = mask or create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@ -172,9 +173,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:

View File

@ -146,13 +146,14 @@ class GPTNeoXModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
mask = create_attention_mask(hidden_states, cache)
mask = mask or create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
@ -176,9 +177,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return out
def sanitize(self, weights):

View File

@ -239,11 +239,12 @@ class HunYuanModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -266,9 +267,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):

View File

@ -193,11 +193,12 @@ class InternLM2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -220,9 +221,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out)
else:

View File

@ -155,11 +155,12 @@ class LlamaModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -182,9 +183,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@ -158,11 +158,12 @@ class MiniCPMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.args.scale_emb
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -186,9 +187,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))

View File

@ -162,11 +162,12 @@ class MixtralModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -188,9 +189,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@ -176,11 +176,12 @@ class NemotronModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -203,9 +204,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@ -124,11 +124,12 @@ class Transformer(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
@ -152,9 +153,10 @@ class OlmoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.transformer(inputs, cache)
return self.transformer(inputs, mask, cache)
class Model(nn.Module):
@ -167,9 +169,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.model(inputs, cache)
return self.model(inputs, mask, cache)
@property
def layers(self):

View File

@ -163,10 +163,11 @@ class LlamaModel(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -190,8 +191,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache=None,
mask=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@ -178,11 +178,12 @@ class OpenELMModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.token_embeddings(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -205,9 +206,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, cache)
out = self.transformer(inputs, mask, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:

View File

@ -143,10 +143,10 @@ class PhiModel(nn.Module):
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, cache):
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
mask = create_attention_mask(x, cache)
mask = mask or create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -167,9 +167,10 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.model(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property

View File

@ -168,11 +168,12 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -194,9 +195,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
@property

View File

@ -258,13 +258,14 @@ class Phi3Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -290,9 +291,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier:
out = out / self.mup_width_multiplier

View File

@ -155,11 +155,12 @@ class PhiMoEModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -181,9 +182,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@ -175,7 +175,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
mask = mask or create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)

View File

@ -174,10 +174,11 @@ class PlamoModel(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
@ -202,8 +203,9 @@ class Model(nn.Module):
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
out = self.model(inputs, cache)
out = self.model(inputs, cache, mask)
return self.lm_head(out)
@property

View File

@ -123,7 +123,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = create_attention_mask(x, cache)
mask = mask or create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@ -149,11 +149,12 @@ class Qwen2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -176,9 +177,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@ -187,11 +187,12 @@ class Qwen2MoeModel(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -213,9 +214,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):

View File

@ -389,6 +389,7 @@ class Griffin(nn.Module):
def __call__(
self,
tokens,
mask: mx.array = None,
cache=None,
):
x = self.embed_tokens(tokens)
@ -402,7 +403,7 @@ class Griffin(nn.Module):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
mask = mask or create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@ -418,12 +419,12 @@ class Model(nn.Module):
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, cache=cache)
logits = self.model(tokens, mask=mask, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:

View File

@ -199,7 +199,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
mask = mask or create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)

View File

@ -125,11 +125,12 @@ class Starcoder2Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = mask or create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
@ -152,9 +153,10 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, cache)
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:

View File

@ -182,7 +182,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.dtype, t)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
outputs = model(inputs, cache=cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
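
The switch to `cache=cache` in this test matters because `mask` now sits between `inputs` and `cache` in most signatures, so the old positional call would have bound the cache object to `mask`. A hedged sketch of extending the same check with an explicit mask, assuming the surrounding test scope (`model`, `inputs`, `vocab_size`, `t`) and an attention-based model that accepts the keyword; `create_attention_mask` is the same helper the model files import from mlx_lm.models.base:

from mlx_lm.models.base import create_attention_mask

cache = make_prompt_cache(model)
h = mx.zeros((1, 2, 8), dtype=t)     # stand-in activations; only shape[1] and dtype matter here
mask = create_attention_mask(h, cache)
outputs = model(inputs, mask=mask, cache=cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)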