mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
* Unify attention mask creation in LLMs.
Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:
```
mask = None
if h.shape[1] > 1:
    mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
    mask = mask.astype(h.dtype)
```
This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input is at the beginning of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape
and will cause an exception to be raised in the attention computation.
Some of the models correctly implement the mask creation with code like this:
```
mask = None
if h.shape[1] > 1:
    mask = create_additive_causal_mask(
        h.shape[1], cache[0].offset if cache is not None else 0
    )
    mask = mask.astype(h.dtype)
```
This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
* Allow batches in LLM key-value cache
The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.
This change removes the hard-coded batch size from the code that
resizes the key-value cache.
* Simplify causal mask creation
Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).
* Use old-style type annotation to avoid linter error
205 lines · 6.0 KiB · Python
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import numpy as np
|
|
|
|
from .base import BaseModelArgs, create_attention_mask
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyperparameters for the GPT-2 model, as read from the model config.

    ``num_key_value_heads`` is optional; when omitted it falls back to
    ``n_head`` (one key/value head per query head).
    """

    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    # Defaults to None and is filled in by __post_init__, so the annotation
    # must admit None.
    num_key_value_heads: Optional[int] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection (``c_attn``)
    and an output projection (``c_proj``)."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: input activations, unpacked below as ``(B, L, n_embd)``.
            mask: optional attention mask forwarded to
                ``mx.fast.scaled_dot_product_attention``.
            cache: optional key/value cache object exposing
                ``update_and_fetch(keys, values)``; None disables caching.

        Returns:
            Projected attention output of shape ``(B, L, n_embd)``.
        """
        B, L, _ = x.shape

        # One matmul produces queries, keys and values; split on the last axis.
        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # (B, L, n_embd) -> (B, n_head, L, head_dim) for per-head attention.
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Presumably stores the new keys/values and returns the full
            # (previous + new) sequence to attend over.
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        # Merge heads back: (B, n_head, L, head_dim) -> (B, L, n_embd).
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4 * n_embd, apply ``nn.gelu_approx``,
    then project back to n_embd."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        inner_dim = 4 * self.n_embd
        self.c_fc = nn.Linear(self.n_embd, inner_dim)
        self.c_proj = nn.Linear(inner_dim, self.n_embd)

    def __call__(self, x) -> mx.array:
        widened = self.c_fc(x)
        activated = nn.gelu_approx(widened)
        return self.c_proj(activated)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One decoder layer: LayerNorm before each sub-layer (attention, then MLP),
    each wrapped in a residual connection."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply the block to ``x``.

        Args:
            x: input activations.
            mask: optional attention mask passed through to ``Attention``.
            cache: optional per-layer key/value cache object passed through
                to ``Attention`` (it calls ``update_and_fetch`` on it).

        Returns:
            Output activations with the same shape as ``x``.
        """
        # Attention sub-layer, normalised input, residual add.
        h = x + self.attn(self.ln_1(x), mask, cache)
        # MLP sub-layer, normalised input, residual add.
        return h + self.mlp(self.ln_2(h))
|
|
|
|
|
|
class GPT2Model(nn.Module):
    """GPT-2 decoder stack: token embeddings (``wte``) plus learned position
    embeddings (``wpe``), ``n_layer`` transformer blocks, and a final LayerNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs`` and run them through every decoder block.

        Args:
            inputs: token ids, unpacked below as ``(batch, L)``.
            cache: optional list with one key/value cache per block, or None.
                Entries are assumed to expose ``offset`` (number of positions
                already cached) — TODO confirm against the cache class.

        Returns:
            Hidden states after the final LayerNorm.
        """
        _, L = inputs.shape

        # Absolute position of the first new token. Without this, tokens
        # decoded after a cached prompt would be embedded as if they started
        # at position 0.
        offset = 0
        if cache is not None and cache[0] is not None:
            offset = cache[0].offset

        hidden_states = self.wte(inputs)

        # Learned position embeddings must be added on every call, including
        # single-token decoding steps. offset + L must stay within
        # n_positions (the size of wpe).
        position_ids = mx.array(np.arange(offset, offset + L))
        hidden_states = hidden_states + self.wpe(position_ids)

        # A single new token may attend to everything cached, so no mask is
        # needed in that case.
        mask = None
        if L > 1:
            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)
|
|
|
|
|
|
class Model(nn.Module):
    """GPT-2 language model: the ``GPT2Model`` decoder stack plus an output
    head that reuses the token embedding (``wte.as_linear``) to produce logits."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return vocabulary logits for ``inputs`` (optionally using ``cache``)."""
        out = self.model(inputs, cache)
        # Tied output head: the token embedding matrix doubles as the
        # final projection to vocabulary logits.
        return self.model.wte.as_linear(out)

    def sanitize(self, weights):
        """Convert checkpoint weights to this module's parameter layout.

        For each layer this:
        - drops ``h.{i}.attn.bias``, which has no counterpart in ``Attention``;
        - transposes the c_attn / c_proj / c_fc projection weights.

        It also prefixes every key with ``model.`` unless it already has that
        prefix. The caller's dict is not modified; a new dict is returned.
        """
        # Shallow copy so the caller's mapping is left untouched.
        weights = dict(weights)

        # Presumably the checkpoint stores these projections as (in, out)
        # (HF Conv1D layout) while nn.Linear expects (out, in) — hence the
        # transpose.
        transposed = ("attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj")

        for i in range(self.args.n_layer):
            weights.pop(f"h.{i}.attn.bias", None)
            for name in transposed:
                key = f"h.{i}.{name}.weight"
                if key in weights:
                    weights[key] = weights[key].transpose(1, 0)

        # Same order and overwrite behaviour as the original prefixing loop.
        return {
            (key if key.startswith("model.") else f"model.{key}"): value
            for key, value in weights.items()
        }

    @property
    def layers(self):
        return self.model.h

    @property
    def head_dim(self):
        return self.args.n_embd // self.args.n_head

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
|