mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Length masking for batch inputs (#1173)
* length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix
This commit is contained in:
@@ -123,10 +123,12 @@ class ExaoneModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
mask = create_attention_mask(h, cache)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
@@ -149,9 +151,10 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, cache)
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user