mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Length masking for batch inputs (#1173)
* length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix
This commit is contained in:
@@ -143,10 +143,11 @@ class PhiModel(nn.Module):
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, cache):
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
||||
mask = create_attention_mask(x, cache)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
@@ -167,9 +168,10 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.model(x, cache)
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user