diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index b74dbe3b..cbb5f2cb 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -131,6 +131,13 @@ class TransformerBlock(nn.Module):
         return out, cache
 
 
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
 class LlamaModel(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -153,7 +160,9 @@ class LlamaModel(nn.Module):
 
         mask = None
         if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = create_additive_causal_mask(
+                h.shape[1], cache[0][0].shape[2] if cache is not None else 0
+            )
             mask = mask.astype(h.dtype)
 
         if cache is None:
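For context, a minimal standalone sketch of the new helper and how the offset changes its output (the function body is taken verbatim from the diff; the usage lines below it are illustrative, not part of the patch). With `offset` set to the number of tokens already in the KV cache, the mask only needs rows for the new queries, while its columns span cached plus new key positions:

    import mlx.core as mx

    def create_additive_causal_mask(N: int, offset: int = 0):
        # Columns index all keys (cached + new); rows index only the new queries.
        rinds = mx.arange(offset + N)
        linds = mx.arange(offset, offset + N) if offset else rinds
        # True where a query would attend to a future key; scaled to a large
        # negative additive bias so softmax assigns it ~zero probability.
        mask = linds[:, None] < rinds[None]
        return mask * -1e9

    # Prompt of 3 tokens, empty cache: a square 3x3 mask, upper triangle masked.
    print(create_additive_causal_mask(3))
    # 2 new tokens with 4 tokens already cached: a 2x6 mask whose first 4
    # columns (the cached positions) are fully visible to both queries.
    print(create_additive_causal_mask(2, offset=4))

This is why the call site passes `cache[0][0].shape[2]` as the offset: the keys cached for the first layer have shape (batch, heads, sequence, head_dim), so axis 2 gives the number of previously processed tokens.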