mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Pad mask with zeros for non-square attention matrices (#715)
* Pad mask with zeros for non-square attention matrices The current implementation of the mask assumes the attention matrix is square, which is true if there is no cache. However, if one wishes to produce multiple tokens at a time, such as in speculative decoding implementations, a rectangular mask is necessary. This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly. * Directly create mask instead of padding * Update llama.py
This commit is contained in:
@@ -131,6 +131,13 @@ class TransformerBlock(nn.Module):
|
||||
return out, cache
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@@ -153,7 +160,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = create_additive_causal_mask(
|
||||
h.shape[1], cache[0][0].shape[2] if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
|
Reference in New Issue
Block a user