Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00

* Pad mask with zeros for non-square attention matrices

  The current implementation of the mask assumes the attention matrix is square, which is true if there is no cache. However, if one wishes to produce multiple tokens at a time, such as in speculative decoding implementations, a rectangular mask is necessary. This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly.

* Directly create mask instead of padding
* Update llama.py
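
To illustrate the rectangular mask described above, here is a minimal sketch with hypothetical sizes, assuming the create_additive_causal_mask helper from the llama.py listing below is in scope:

import mlx.core as mx

# Hypothetical sizes: 4 tokens already in the key/value cache, 3 new tokens to score.
offset, N = 4, 3
mask = create_additive_causal_mask(N, offset)
print(mask.shape)  # (3, 7): one row per new query, one column per cached + new key
# The first `offset` columns are all zeros, so every new token can attend to the full
# cache; the trailing 3 x 3 block carries the usual additive causal (-1e9) pattern.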
201 lines · 6.4 KiB · Python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


def create_additive_causal_mask(N: int, offset: int = 0):
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    mask = linds[:, None] < rinds[None]
    return mask * -1e9


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = create_additive_causal_mask(
                h.shape[1], cache[0][0].shape[2] if cache is not None else 0
            )
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
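
For completeness, a minimal end-to-end sketch of multi-token decoding with a cache, the case the rectangular mask is meant to support. It assumes the ModelArgs and Model classes from the listing above are in scope; the tiny configuration values and token ids are hypothetical and the weights are randomly initialized, so only the shapes and mask behaviour are meaningful:

import mlx.core as mx

# Hypothetical tiny configuration; a real checkpoint supplies these values.
args = ModelArgs(
    model_type="llama",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    rms_norm_eps=1e-5,
    vocab_size=100,
)
model = Model(args)

prompt = mx.array([[1, 2, 3, 4]])    # (B=1, L=4) prompt tokens
logits, cache = model(prompt)        # square 4 x 4 causal mask; cache now holds 4 positions

draft = mx.array([[5, 6, 7]])        # (B=1, L=3), e.g. speculative draft tokens
logits, cache = model(draft, cache)  # rectangular 3 x 7 mask: zeros over the cached columns
print(logits.shape)                  # (1, 3, 100)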