Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-09-01 12:49:50 +08:00
Length masking for batch inputs (#1173)

* length masking
* add mask to mlx_lm model interface
* remove lengths
* fix test
* comment + fix
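For context: the diff below threads an optional `mask` argument through the model interface, so callers can supply their own attention mask instead of always getting the causal one built internally; this is what makes padded batches of differing lengths workable. A minimal sketch of how such a mask might be constructed, assuming MLX arrays; the helper name `causal_length_mask` and the (B, 1, T, T) additive-mask shape are illustrative assumptions, not part of mlx_lm:

```python
import mlx.core as mx

def causal_length_mask(lengths: mx.array, T: int) -> mx.array:
    # Hypothetical helper: combine a causal mask with per-sequence length
    # masking so each token attends only to earlier, non-padded positions.
    # lengths: (B,) valid token counts for a right-padded batch of length T.
    # Returns an additive mask of shape (B, 1, T, T) that broadcasts over
    # attention heads: 0.0 where attention is allowed, -inf where it is not.
    idx = mx.arange(T)
    causal = idx[None, :] <= idx[:, None]             # (T, T): key <= query
    in_range = idx[None, :] < lengths[:, None]        # (B, T): key inside sequence
    keep = causal[None, :, :] & in_range[:, None, :]  # (B, T, T)
    return mx.where(keep, 0.0, float("-inf"))[:, None, :, :]
```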
@@ -174,10 +174,12 @@ class PlamoModel(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
@@ -202,8 +204,9 @@ class Model(nn.Module):
         self,
         inputs: mx.array,
         cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
     ) -> mx.array:
-        out = self.model(inputs, cache)
+        out = self.model(inputs, cache, mask)
         return self.lm_head(out)
 
     @property
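Assuming the sketch above, a call through the extended interface might look like the snippet below; `model`, `padded_inputs`, and `lengths` are placeholders. Note that when `mask` is left as None, the model still builds the default causal mask via `create_attention_mask`, so existing single-sequence callers are unaffected:

```python
# Hypothetical usage: right-padded batch of shape (B, T) plus per-sequence
# lengths; the explicit mask keeps padded positions out of attention.
B, T = padded_inputs.shape
mask = causal_length_mask(lengths, T)
logits = model(padded_inputs, cache=None, mask=mask)  # (B, T, vocab_size)
```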