Length masking for batch inputs (#1173)

* length masking

* add mask to mlx_lm model interface

* remove lengths

* fix test

* comment + fix
Alex Barron authored on 2024-12-18 19:43:52 -08:00 (committed by GitHub)
parent db109184b7
commit d4ef909d4a
34 changed files with 191 additions and 72 deletions


@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
             config.hidden_size, eps=config.layer_norm_eps
         )

-    def __call__(self, x, cache):
+    def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
-        mask = create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -167,9 +168,10 @@ class Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
+        mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        y = self.model(x, cache)
+        y = self.model(x, mask, cache)
         return self.lm_head(y)

     @property
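
The diff threads an optional attention mask through the model call so that padded batch inputs can hide their padding tokens; when no mask is given, the model falls back to create_attention_mask as before. Below is a minimal sketch, not part of this commit, of how a caller might build such a length mask and pass it in. The helper name make_length_mask, the (B, 1, 1, L) broadcast shape, and the 0 / -inf additive convention are assumptions about how the attention layers consume the mask.

# Hypothetical helper, not from this commit: build an additive mask that
# hides padding positions in a right-padded batch of token ids.
import mlx.core as mx


def make_length_mask(lengths: mx.array, max_len: int) -> mx.array:
    # lengths: (B,) number of real (non-padding) tokens per sequence
    positions = mx.arange(max_len)                    # (L,)
    valid = positions[None, :] < lengths[:, None]     # (B, L), True for real tokens
    # Additive mask: 0.0 where keys may be attended, -inf where they are padding.
    # Shaped (B, 1, 1, L) so it broadcasts over heads and query positions.
    return mx.where(valid[:, None, None, :], 0.0, float("-inf")).astype(mx.float32)


# Hypothetical usage with a right-padded batch `x` of token ids, shape (B, L):
# lengths = mx.array([7, 12, 12])
# mask = make_length_mask(lengths, x.shape[1])
# logits = model(x, mask=mask)   # with mask=None the model builds its usual mask

Note that this sketch only hides padding; for autoregressive decoding it would still need to be combined with a causal mask, which the default create_attention_mask path supplies when mask is None.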