Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs

Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc code to create a mask for the attention mechanism. This usually takes the form:

```
mask = None
if h.shape[1] > 1:
    mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
    mask = mask.astype(h.dtype)
```

This correctly creates a mask only if the input consists of more than one token. But this code assumes the multi-token input is at the beginning of inference. If, for example, we are evaluating multiple tokens because of speculative decoding or prompt cache reuse, this mask will not have the correct shape and will cause an exception to be raised in the attention computation.

Some of the models implement the mask creation correctly, with code like this:

```
mask = None
if h.shape[1] > 1:
    mask = create_additive_causal_mask(
        h.shape[1], cache[0].offset if cache is not None else 0
    )
    mask = mask.astype(h.dtype)
```

This commit unifies the attention mask creation for all models with a new function `create_attention_mask`, reducing code duplication and helping all models support inference performance enhancements such as those mentioned above.

* Allow batches in the LLM key-value cache

The current implementation of the LLM key-value cache assumes that the input batch is of size 1. Input batching (evaluating multiple alternative inputs at the same time) can be a valuable tool for speculative sampling and other techniques. This change removes the hard-coded batch size from the code that resizes the key-value cache.

* Simplify causal mask creation

Use the same code path regardless of whether there is an offset or not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotations to avoid a linter error
This commit is contained in:
parent 7a3ab1620a
commit 46da74fea2
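A minimal sketch of the shape problem the first bullet describes, using the same `create_additive_causal_mask` helper this commit moves into `base.py`. The numbers (4 cached tokens, a 3-token draft) are illustrative only: with tokens already in the cache, the mask must span `offset + N` key positions, not just `N`.

```python
import mlx.core as mx


def create_additive_causal_mask(N: int, offset: int = 0):
    # Same helper as in models/base.py in this commit.
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    mask = linds[:, None] < rinds[None]
    return mask * -1e9


# Start of inference: 3 prompt tokens, nothing cached yet.
print(create_additive_causal_mask(3).shape)             # (3, 3)

# Mid-generation: 4 tokens already in the KV cache, 3 speculative tokens
# to score. The mask must cover all 7 key positions; a plain (3, 3) causal
# mask would not match the attention scores and raises a shape error.
print(create_additive_causal_mask(3, offset=4).shape)   # (3, 7)
```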
@@ -1,14 +1,9 @@
 import inspect
 from dataclasses import dataclass
 from typing import List, Optional
 
 import mlx.core as mx
-
-def create_additive_causal_mask(N: int, offset: int = 0):
-    rinds = mx.arange(offset + N)
-    linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
-    return mask * -1e9
+import mlx.nn as nn
 
 
 class KVCache:
@@ -29,9 +24,10 @@ class KVCache:
     def update_and_fetch(self, keys, values):
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+            B = keys.shape[0]
             n_steps = (self.step + keys.shape[2] - 1) // self.step
-            k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
-            v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+            k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+            v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
             new_k = mx.zeros(k_shape, keys.dtype)
             new_v = mx.zeros(v_shape, values.dtype)
             if self.keys is not None:
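A short, hedged usage sketch of the batched cache behavior added above; the import path and the positional constructor arguments (`KVCache(head_dim, n_kv_heads)`) are assumed from the surrounding diff and may differ in other versions of `base.py`.

```python
import mlx.core as mx
from mlx_lm.models.base import KVCache  # assumed path for this commit's base.py

# Two alternative continuations evaluated at once (B = 2), 8 KV heads,
# 5 new tokens, head dimension 64.
cache = KVCache(64, 8)
keys = mx.zeros((2, 8, 5, 64))
values = mx.zeros((2, 8, 5, 64))

k, v = cache.update_and_fetch(keys, values)
# The preallocated buffers now take their leading dimension from the input,
# so the fetched keys/values keep the batch size instead of assuming B = 1.
print(k.shape, v.shape)  # (2, 8, 5, 64) (2, 8, 5, 64)
```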
@@ -60,3 +56,24 @@ class BaseModelArgs:
                 if k in inspect.signature(cls).parameters
             }
         )
+
+
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
+def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+    T = h.shape[1]
+    if T > 1:
+        # Input consists of multiple tokens, create a causal mask so that prior
+        # tokens do not give attention to later tokens. If a cache is in place
+        # (because e.g. prompt reuse), offset the mask accordingly.
+        offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+        mask = create_additive_causal_mask(T, offset)
+        mask = mask.astype(h.dtype)
+    else:
+        mask = None
+    return mask
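A brief sketch of the two paths through the new helper: a single-token decode step gets `mask = None`, while a multi-token step gets a causal mask (offset by the cache position when a cache is present). The import path is assumed; the shapes are illustrative.

```python
import mlx.core as mx
from mlx_lm.models.base import create_attention_mask  # assumed path

h_decode = mx.zeros((1, 1, 512))  # one new token: no mask needed
h_prompt = mx.zeros((1, 3, 512))  # multi-token input (prompt or draft)

print(create_attention_mask(h_decode, cache=None))        # None
print(create_attention_mask(h_prompt, cache=None).shape)  # (3, 3)
```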
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -157,10 +157,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -199,11 +199,7 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -408,11 +408,7 @@ class DeepseekV2Model(nn.Module):
         cache: Optional[KVCache] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -141,10 +141,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -165,10 +165,7 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -136,10 +136,7 @@ class GPT2Model(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -147,10 +147,7 @@ class GPTBigCodeModel(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
 
-        mask = create_additive_causal_mask(
-            hidden_states.shape[1], cache[0].offset if cache is not None else 0
-        )
-        mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
 
 # Based on the transformers implementation at:
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -150,12 +150,7 @@ class GPTNeoXModel(nn.Module):
 
         hidden_states = self.embed_in(inputs)
 
-        mask = None
-        if hidden_states.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                hidden_states.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(hidden_states.dtype)
+        mask = create_attention_mask(hidden_states, cache)
 
         if cache is None:
             cache = [None] * len(self.h)

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -195,10 +195,7 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -271,12 +271,7 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                h.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -160,10 +160,7 @@ class MiniCPMModel(nn.Module):
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -164,11 +164,7 @@ class MixtralModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        T = h.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 try:
     import hf_olmo
@@ -126,10 +126,7 @@ class Transformer(nn.Module):
     ):
         h = self.wte(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.blocks)

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -180,10 +180,7 @@ class OpenELMModel(nn.Module):
     ):
         h = self.token_embeddings(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -138,14 +138,12 @@ class PhiModel(nn.Module):
 
     def __call__(self, x, cache):
         x = self.embed_tokens(x)
 
+        mask = create_attention_mask(x, cache)
+
         if cache is None:
             cache = [None] * len(self.layers)
 
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
         for layer, c in zip(self.layers, cache):
             x = layer(x, mask, c)
         return self.final_layernorm(x)

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .su_rope import SuScaledRotaryEmbedding
 
 
@@ -172,10 +172,7 @@ class Phi3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -6,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -263,10 +263,7 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -6,6 +6,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
+from .base import create_attention_mask
 from .switch_layers import SwitchMLP
 
 
@@ -167,10 +168,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -171,10 +171,7 @@ class PlamoModel(nn.Module):
     ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(self.embed_tokens.weight.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]

@@ -4,7 +4,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -122,11 +122,7 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
 
-        mask = None
-        T = x.shape[1]
-        if T > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-            mask = mask.astype(x.dtype)
+        mask = create_attention_mask(x, cache)
 
         if cache is None:
             cache = [None] * len(self.h)

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -151,10 +151,7 @@ class Qwen2Model(nn.Module):
    ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 from .switch_layers import SwitchGLU
 
 
@@ -189,10 +189,7 @@ class Qwen2MoeModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)

@@ -5,7 +5,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
 
 
 @dataclass
@@ -198,11 +198,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
+        mask = create_attention_mask(x, cache)
         y = self.model(x, mask, cache)
         return self.lm_head(y)
 

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
 @dataclass
@@ -127,10 +127,7 @@ class Starcoder2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
+        mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)