Pad mask with zeros for non-square attention matrices (#715)

* Pad mask with zeros for non-square attention matrices

The current implementation of the mask assumes the attention matrix is square, which holds when there is no cache. However, to produce multiple tokens at a time with a cache, such as in speculative decoding, a rectangular mask is necessary: the queries cover only the new tokens, while the keys cover the cached tokens as well as the new ones.

This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly.

* Directly create the rectangular mask instead of padding (a short sketch follows these notes)

* Update llama.py
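
For reference, here is a minimal sketch of the mask the final approach builds. The helper restates create_additive_causal_mask from the diff further down; the example sizes (4 cached tokens, 2 new tokens) are arbitrary and only chosen to show the resulting N x (offset + N) shape.

import mlx.core as mx


def create_additive_causal_mask(N: int, offset: int = 0):
    # Key positions span the cached tokens plus the new tokens.
    rinds = mx.arange(offset + N)
    # Query positions start after the cached tokens.
    linds = mx.arange(offset, offset + N) if offset else rinds
    # True where the key position is ahead of the query position (disallowed).
    mask = linds[:, None] < rinds[None]
    # Large negative additive values drive masked scores to ~0 after softmax.
    return mask * -1e9


# Example: 4 tokens already in the cache, 2 new tokens decoded at once.
mask = create_additive_causal_mask(2, offset=4)
print(mask.shape)  # (2, 6): 2 query rows over 4 cached + 2 new key columns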
Author: Kevin Wang, 2024-05-04 19:32:25 -04:00 (committed by GitHub)
Parent: f30413b63c
Commit: c0019c4908

@@ -131,6 +131,13 @@ class TransformerBlock(nn.Module):
         return out, cache
 
 
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
 class LlamaModel(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -153,7 +160,9 @@ class LlamaModel(nn.Module):
 
         mask = None
         if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = create_additive_causal_mask(
+                h.shape[1], cache[0][0].shape[2] if cache is not None else 0
+            )
             mask = mask.astype(h.dtype)
 
         if cache is None:
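
As a usage note, the call site takes the offset from the time dimension of the cached keys. The snippet below is only a sketch: the cache layout (per-layer (keys, values) pairs with keys shaped [batch, heads, cached_length, head_dim]) is an assumption consistent with the diff's use of cache[0][0].shape[2], the zero tensors stand in for real cached activations, and create_additive_causal_mask is the helper shown above.

import mlx.core as mx

# Hypothetical cache: one (keys, values) pair per layer; shapes are assumed.
B, n_heads, cached_len, head_dim = 1, 8, 4, 64
cache = [(
    mx.zeros((B, n_heads, cached_len, head_dim)),  # cached keys
    mx.zeros((B, n_heads, cached_len, head_dim)),  # cached values
)]

N = 2  # number of new tokens decoded in one step (e.g. a speculative draft)
offset = cache[0][0].shape[2] if cache is not None else 0
mask = create_additive_causal_mask(N, offset)
print(mask.shape)  # (2, 6): matches attention scores of shape [..., N, offset + N]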