Unify attention mask in LLMs (#911)

* Unify attention mask creation in LLMs.

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:

```
    mask = None
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
and will cause the raising of an exception in the attention computation.

Some of the models correctly implement the mask creation with code like this:

```
    mask = None
    if h.shape[1] > 1:
        mask = create_additive_causal_mask(
            h.shape[1], cache[0].offset if cache is not None else 0
        )
        mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.

* Allow batches in LLM key-value cache

The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alterative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.

This change removes the hard-coded batch size from the code that
resizes the key-value cache.

* Simplify causal mask creation

Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotation to avoid linter error
This commit is contained in:
otriscon 2024-07-25 19:45:22 -04:00 committed by GitHub
parent 7a3ab1620a
commit 46da74fea2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 76 additions and 138 deletions

View File

@ -1,14 +1,9 @@
import inspect
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
def create_additive_causal_mask(N: int, offset: int = 0):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
return mask * -1e9
import mlx.nn as nn
class KVCache:
@ -29,9 +24,10 @@ class KVCache:
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
@ -60,3 +56,24 @@ class BaseModelArgs:
if k in inspect.signature(cls).parameters
}
)
def create_additive_causal_mask(N: int, offset: int = 0):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
T = h.shape[1]
if T > 1:
# Input consists of multiple tokens, create a causal mask so that prior
# tokens do not give attention to later tokens. If a cache is in place
# (because e.g. prompt reuse), offset the mask accordingly.
offset = cache[0].offset if cache is not None and cache[0] is not None else 0
mask = create_additive_causal_mask(T, offset)
mask = mask.astype(h.dtype)
else:
mask = None
return mask

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -157,10 +157,7 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -199,11 +199,7 @@ class DBRX(nn.Module):
):
h = self.wte(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU
@ -408,11 +408,7 @@ class DeepseekV2Model(nn.Module):
cache: Optional[KVCache] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -141,10 +141,7 @@ class GemmaModel(nn.Module):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -165,10 +165,7 @@ class GemmaModel(nn.Module):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -136,10 +136,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_additive_causal_mask(
hidden_states.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(hidden_states.dtype)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -147,10 +147,7 @@ class GPTBigCodeModel(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_additive_causal_mask(
hidden_states.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(hidden_states.dtype)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, create_attention_mask
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@ -150,12 +150,7 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.embed_in(inputs)
mask = None
if hidden_states.shape[1] > 1:
mask = create_additive_causal_mask(
hidden_states.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(hidden_states.dtype)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -195,10 +195,7 @@ class InternLM2Model(nn.Module):
):
h = self.tok_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_additive_causal_mask
from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@ -271,12 +271,7 @@ class LlamaModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -160,10 +160,7 @@ class MiniCPMModel(nn.Module):
):
h = self.embed_tokens(inputs) * self.args.scale_emb
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@ -164,11 +164,7 @@ class MixtralModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
try:
import hf_olmo
@ -126,10 +126,7 @@ class Transformer(nn.Module):
):
h = self.wte(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)

View File

@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -180,10 +180,7 @@ class OpenELMModel(nn.Module):
):
h = self.token_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -138,14 +138,12 @@ class PhiModel(nn.Module):
def __call__(self, x, cache):
x = self.embed_tokens(x)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)

View File

@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@ -172,10 +172,7 @@ class Phi3Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -6,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@ -263,10 +263,7 @@ class Phi3Model(nn.Module):
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -6,6 +6,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import create_attention_mask
from .switch_layers import SwitchMLP
@ -167,10 +168,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -171,10 +171,7 @@ class PlamoModel(nn.Module):
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(self.embed_tokens.weight.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]

View File

@ -4,7 +4,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -122,11 +122,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = None
T = x.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(x.dtype)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@ -151,10 +151,7 @@ class Qwen2Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU
@ -189,10 +189,7 @@ class Qwen2MoeModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)

View File

@ -5,7 +5,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -198,11 +198,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@ -127,10 +127,7 @@ class Starcoder2Model(nn.Module):
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)