From c0019c4908068ac9a3a62d8f2e376575b93dbd26 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sat, 4 May 2024 19:32:25 -0400 Subject: [PATCH] Pad mask with zeros for non-square attention matrices (#715) * Pad mask with zeros for non-square attention matrices The current implementation of the mask assumes the attention matrix is square, which is true if there is no cache. However, if one wishes to produce multiple tokens at a time, such as in speculative decoding implementations, a rectangular mask is necessary. This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly. * Directly create mask instead of padding * Update llama.py --- llms/mlx_lm/models/llama.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index b74dbe3b..cbb5f2cb 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -131,6 +131,13 @@ class TransformerBlock(nn.Module): return out, cache +def create_additive_causal_mask(N: int, offset: int = 0): + rinds = mx.arange(offset + N) + linds = mx.arange(offset, offset + N) if offset else rinds + mask = linds[:, None] < rinds[None] + return mask * -1e9 + + class LlamaModel(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -153,7 +160,9 @@ class LlamaModel(nn.Module): mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = create_additive_causal_mask( + h.shape[1], cache[0][0].shape[2] if cache is not None else 0 + ) mask = mask.astype(h.dtype) if cache is None: