Length masking for batch inputs (#1173)

* length masking

* add mask to mlx_lm model interface

* remove lengths

* fix test

* comment + fix
Author: Alex Barron
Date: 2024-12-18 19:43:52 -08:00 (committed by GitHub)
Parent: db109184b7
Commit: d4ef909d4a
34 changed files with 191 additions and 72 deletions


@@ -23,7 +23,12 @@ class BaseModelArgs:
         )
-def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
+def create_causal_mask(
+    N: int,
+    offset: int = 0,
+    window_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
+):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
@@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     mask = linds < rinds
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
+    if lengths is not None:
+        lengths = lengths[:, None, None, None]
+        mask = mask | (rinds >= lengths)
     return mask * -1e9
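
With the new lengths argument, create_causal_mask adds batch awareness on top of the usual causal (and optional sliding-window) structure: positions at or beyond a sequence's true length receive the same -1e9 additive penalty as future positions. A minimal sketch of driving the helper for a right-padded batch; the import path is an assumption about where this base module lives in the installed mlx_lm package, not part of the diff:

    import mlx.core as mx

    from mlx_lm.models.base import create_causal_mask  # assumed import path for the file above

    # Two prompts right-padded to a shared length of 5; true lengths are 5 and 3.
    tokens = mx.array([[1, 2, 3, 4, 5],
                       [1, 2, 3, 0, 0]])
    lengths = mx.array([5, 3])

    # Causal mask with the padded tail of the second row pushed to -1e9.
    mask = create_causal_mask(tokens.shape[1], lengths=lengths)
    print(mask.shape)  # broadcasts to (batch, 1, N, N): (2, 1, 5, 5)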


@@ -155,11 +155,13 @@ class CohereModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -180,9 +182,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
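
The same two-part change repeats in every model file below: __call__ gains an optional mask parameter, the attention mask is only built internally when the caller has not supplied one, and the top-level Model threads the mask down to the inner module. A hedged sketch of how a caller might use the updated interface for a padded batch; the checkpoint name and the dtype cast are illustrative assumptions, not part of this commit:

    import mlx.core as mx

    from mlx_lm import load
    from mlx_lm.models.base import create_causal_mask  # assumed helper location

    model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # any mlx_lm checkpoint

    ids = [tokenizer.encode(p) for p in ["hello there", "hi"]]
    lengths = mx.array([len(i) for i in ids])
    max_len = max(len(i) for i in ids)

    # Right-pad to a rectangle; the lengths-aware mask hides the padding tokens.
    batch = mx.array([i + [0] * (max_len - len(i)) for i in ids])
    mask = create_causal_mask(max_len, lengths=lengths).astype(mx.float16)  # match the model's compute dtype if needed

    logits = model(batch, mask=mask)  # (2, max_len, vocab_size)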


@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import KVCache, RotatingKVCache
@@ -151,16 +151,13 @@ class CohereModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        T = h.shape[1]
-        if T > 1:
-            offset = cache[0].offset if cache else 0
-            mask = create_causal_mask(T, offset).astype(h.dtype)
-        else:
-            mask = None
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -181,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out


@@ -197,11 +197,13 @@ class DBRX(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.blocks)
@@ -223,9 +225,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         return self.lm_head(out)
     @property


@@ -211,9 +211,11 @@ class DeepseekModel(nn.Module):
         self,
         x: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -236,8 +238,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
     def sanitize(self, weights):


@@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module):
         self,
         x: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -395,8 +398,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
     def sanitize(self, weights):


@@ -123,10 +123,12 @@ class ExaoneModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.h)
@@ -149,9 +151,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.transformer.wte.as_linear(out)
         else:


@@ -138,12 +138,14 @@ class GemmaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -164,9 +166,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         return out


@@ -160,12 +160,14 @@ class GemmaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -187,9 +189,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = mx.tanh(out / self.final_logit_softcapping)
         out = out * self.final_logit_softcapping


@@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         _, L = inputs.shape
@@ -138,7 +139,8 @@ class GPT2Model(nn.Module):
             position_ids = mx.array(np.arange(L))
             hidden_states += self.wpe(position_ids)
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)
@@ -159,9 +161,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.wte.as_linear(out)
         return out


@@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         B, L = inputs.shape
@@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
             position_ids = mx.array(np.arange(L))
             hidden_states += self.wpe(position_ids)
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)
@@ -172,9 +174,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.transformer.wte.as_linear(out)
         else:


@@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         _, L = inputs.shape
         hidden_states = self.embed_in(inputs)
-        mask = create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)
@@ -176,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return out
     def sanitize(self, weights):


@@ -239,11 +239,13 @@ class HunYuanModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -266,9 +268,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.model.embed_tokens.as_linear(out)
     def sanitize(self, weights):


@@ -193,11 +193,13 @@ class InternLM2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.tok_embeddings(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -220,9 +222,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.tok_embeddings.as_linear(out)
         else:


@@ -155,11 +155,13 @@ class LlamaModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -182,9 +184,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:


@@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -186,9 +188,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if not self.args.tie_word_embeddings:
             out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))


@@ -162,11 +162,13 @@ class MixtralModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -188,9 +190,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
     def sanitize(self, weights):


@@ -176,11 +176,13 @@ class NemotronModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -203,9 +205,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:


@@ -124,11 +124,13 @@ class Transformer(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.blocks)
@@ -152,9 +154,10 @@ class OlmoModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        return self.transformer(inputs, cache)
+        return self.transformer(inputs, mask, cache)
 class Model(nn.Module):
@@ -167,9 +170,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        return self.model(inputs, cache)
+        return self.model(inputs, mask, cache)
     @property
     def layers(self):


@@ -163,10 +163,12 @@ class LlamaModel(nn.Module):
         self,
         inputs: mx.array,
         cache=None,
+        mask=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -190,8 +192,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache=None,
+        mask=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:


@@ -178,11 +178,13 @@ class OpenELMModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.token_embeddings(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -205,9 +207,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.transformer(inputs, cache)
+        out = self.transformer(inputs, mask, cache)
         if self.args.share_input_output_layers:
             out = self.transformer.token_embeddings.as_linear(out)
         else:


@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
             config.hidden_size, eps=config.layer_norm_eps
         )
-    def __call__(self, x, cache):
+    def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -167,9 +168,10 @@ class Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
+        mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        y = self.model(x, cache)
+        y = self.model(x, mask, cache)
         return self.lm_head(y)
     @property


@@ -168,11 +168,13 @@ class Phi3Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -194,9 +196,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
     @property


@@ -258,13 +258,15 @@ class Phi3Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -290,9 +292,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         if self.mup_width_multiplier:
             out = out / self.mup_width_multiplier


@@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -181,9 +183,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
     def sanitize(self, weights):


@@ -175,7 +175,9 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)


@@ -174,10 +174,12 @@ class PlamoModel(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
@@ -202,8 +204,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
     @property


@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         if cache is None:
             cache = [None] * len(self.h)


@@ -149,11 +149,13 @@ class Qwen2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -176,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:


@@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -213,9 +215,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         return self.lm_head(out)
     def sanitize(self, weights):


@@ -389,6 +389,7 @@ class Griffin(nn.Module):
     def __call__(
         self,
         tokens,
+        mask: mx.array = None,
         cache=None,
     ):
         x = self.embed_tokens(tokens)
@@ -402,7 +403,8 @@ class Griffin(nn.Module):
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]
-        mask = create_attention_mask(x, mask_cache)
+        if mask is None:
+            mask = create_attention_mask(x, mask_cache)
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
@@ -418,12 +420,12 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
+    def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
         """
         Args:
            tokens: Sequence of input tokens.
         """
-        logits = self.model(tokens, cache=cache)
+        logits = self.model(tokens, mask=mask, cache=cache)
         if "lm_head" in self:
             logits = self.lm_head(logits)
         else:


@@ -199,7 +199,10 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         y = self.model(x, mask, cache)
         return self.lm_head(y)


@@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
@@ -152,9 +154,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else: