Length masking for batch inputs (#1173)

* length masking

* add mask to mlx_lm model interface

* remove lengths

* fix test

* comment + fix
Author: Alex Barron
Date:   2024-12-18 19:43:52 -08:00
Committed by: GitHub
Parent: db109184b7
Commit: d4ef909d4a
34 changed files with 191 additions and 72 deletions
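
The hunks below add an optional `mask` argument to the model interface so that padded batch inputs can keep attention off their padding tokens. As a rough illustration of the idea only (not code from this commit; `make_batch_mask` is a hypothetical helper, and the exact mask shape and dtype conventions mlx_lm expects are assumptions here), a length-aware additive causal mask for a padded batch could be built like this:

import mlx.core as mx

def make_batch_mask(T: int, lengths: mx.array) -> mx.array:
    # Additive mask, broadcastable over attention heads:
    # 0 where attending is allowed, -1e9 where it is masked out.
    idx = mx.arange(T)
    causal = idx[:, None] >= idx[None, :]                 # (T, T): key j <= query i
    in_range = idx[None, :] < lengths[:, None]            # (B, T): key j is not padding
    allowed = mx.logical_and(
        causal[None, None, :, :],                         # (1, 1, T, T)
        in_range[:, None, None, :],                       # (B, 1, 1, T)
    )
    return mx.where(allowed, 0.0, -1e9)                   # (B, 1, T, T)

# Example: batch of 2 prompts padded to length 5, true lengths 5 and 3.
mask = make_batch_mask(5, mx.array([5, 3]))
print(mask.shape)  # (2, 1, 5, 5)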

@@ -389,6 +389,7 @@ class Griffin(nn.Module):
     def __call__(
         self,
         tokens,
+        mask: mx.array = None,
         cache=None,
     ):
         x = self.embed_tokens(tokens)
@@ -402,7 +403,8 @@ class Griffin(nn.Module):
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]

-        mask = create_attention_mask(x, mask_cache)
+        if mask is None:
+            mask = create_attention_mask(x, mask_cache)

         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
@@ -418,12 +420,12 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

-    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
+    def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        logits = self.model(tokens, cache=cache)
+        logits = self.model(tokens, mask=mask, cache=cache)
         if "lm_head" in self:
             logits = self.lm_head(logits)
         else:
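
For reference, a hypothetical call site for the updated signature (assuming `model` is an already-constructed mlx_lm model instance and reusing the make_batch_mask sketch above):

import mlx.core as mx

tokens = mx.zeros((2, 5), dtype=mx.int32)    # dummy padded batch of token ids, shape (B, T)
mask = make_batch_mask(5, mx.array([5, 3]))  # hide padding positions in attention

logits = model(tokens, mask=mask)  # explicit length-aware mask for the padded batch
logits = model(tokens)             # mask=None: a causal mask is created internally, as before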