mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Length masking for batch inputs (#1173)
* length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix
This commit is contained in:
@@ -155,11 +155,13 @@ class CohereModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = create_attention_mask(h, cache)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
@@ -180,9 +182,10 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
Reference in New Issue
Block a user