mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-03 23:14:34 +08:00
Length masking for batch inputs (#1173)
* length masking
* add mask to mlx_lm model interface
* remove lengths
* fix test:
* comment + fix
This commit is contained in:
@@ -199,7 +199,10 @@ class Model(nn.Module):
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
|
Reference in New Issue
Block a user