Length masking for batch inputs (#1173)

* length masking

* add mask to mlx_lm model interface

* remove lengths

* fix test

* comment + fix
Alex Barron authored on 2024-12-18 19:43:52 -08:00, committed by GitHub
parent db109184b7
commit d4ef909d4a
34 changed files with 191 additions and 72 deletions
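The change threads an optional mask argument through every mlx_lm model interface so that batches of unequal-length (padded) prompts can be attended over correctly. As a minimal sketch, not part of the diff, here is how a caller might build a length-aware additive causal mask and pass it through the new argument; the batch_causal_mask helper and the model/lengths/padded_inputs names are hypothetical.

import mlx.core as mx

def batch_causal_mask(T: int, lengths: mx.array, dtype=mx.float32) -> mx.array:
    # Hypothetical helper: combine a (T, T) causal mask with per-sequence
    # padding masks so queries never attend past each prompt's true length.
    idx = mx.arange(T)
    causal = idx[:, None] >= idx[None, :]     # (T, T) lower-triangular
    valid = idx[None, :] < lengths[:, None]   # (B, T), True for real tokens
    allowed = mx.logical_and(causal[None, None], valid[:, None, None, :])
    # Additive convention: 0 where attention is allowed, -1e9 elsewhere.
    return mx.where(allowed, 0.0, -1e9).astype(dtype)

# Hypothetical usage: two prompts padded to length 5.
# lengths = mx.array([5, 3])
# mask = batch_causal_mask(5, lengths)
# logits = model(padded_inputs, mask=mask, cache=cache)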


@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import KVCache, RotatingKVCache
@@ -151,16 +151,13 @@ class CohereModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
         h = self.embed_tokens(inputs)
 
-        T = h.shape[1]
-        if T > 1:
-            offset = cache[0].offset if cache else 0
-            mask = create_causal_mask(T, offset).astype(h.dtype)
-        else:
-            mask = None
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
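The deleted inline logic now sits behind the `mask is None` fallback: when no mask is supplied, the model derives one from the hidden states and cache as before, so single-sequence generation is unchanged. Roughly, the fallback behaves like the sketch below (a paraphrase for illustration, not the exact create_attention_mask implementation):

def attention_mask_fallback(h, cache=None):
    # Paraphrase of the default path: only build a causal mask when more
    # than one new token is processed; a single decode step needs none.
    T = h.shape[1]
    if T <= 1:
        return None
    offset = cache[0].offset if cache else 0
    return create_causal_mask(T, offset).astype(h.dtype)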
@@ -181,9 +178,10 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
+        mask: mx.array = None,
         cache=None,
     ):
-        out = self.model(inputs, cache)
+        out = self.model(inputs, mask, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
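The outer Model wrapper now forwards the mask to the inner model, so the new argument is available at the top level while remaining optional. A hypothetical call against the updated interface (model, inputs, padded_inputs, and mask are placeholder names):

# mask defaults to None, so existing single-sequence calls are unaffected.
logits = model(inputs)                    # unchanged behaviour
logits = model(padded_inputs, mask=mask)  # padded batch with length masking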