mlx-examples/llms/mlx_lm/models/gpt_neox.py

217 lines
6.3 KiB
Python
Raw Normal View History

# Copyright © 2023-2024 Apple Inc.
2024-07-11 21:13:17 +08:00
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
2024-07-11 21:13:17 +08:00
import mlx.core as mx
import mlx.nn as nn
import numpy as np
Unify attention mask in LLMs (#911) * Unify attention mask creation in LLMs. Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc code to create a mask for the attention mechanism. This usually takes the form: ``` mask = None if h.shape[1] > 1: mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) ``` This correctly creates a mask only if the input consists of more than one token. But this code assumes the multi-token input is at the beginning of inference. If, for example, we are evaluating multiple tokens because of speculative decoding or prompt cache reuse, this mask will not have the correct shape and will cause an exception to be raised in the attention computation. Some of the models correctly implement the mask creation with code like this: ``` mask = None if h.shape[1] > 1: mask = create_additive_causal_mask( h.shape[1], cache[0].offset if cache is not None else 0 ) mask = mask.astype(h.dtype) ``` This commit unifies the attention mask creation for all models with a new function `create_attention_mask`, reducing code duplication and helping all models support inference performance enhancements like those mentioned above. * Allow batches in LLM key-value cache. The current implementation of the LLM key-value cache assumes that the input batch is of size 1. Input batching (evaluating multiple alternative inputs at the same time) can be a valuable tool for speculative sampling and other techniques. This change removes the hard-coded batch size from the code that resizes the key-value cache. * Simplify causal mask creation. Use the same codepath regardless of whether there's an offset or not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717). * Use old-style type annotation to avoid linter error.
2024-07-26 07:45:22 +08:00
from .base import BaseModelArgs, create_attention_mask
2024-07-11 21:13:17 +08:00
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters of a GPT-NeoX checkpoint.

    Field names mirror the keys of the Hugging Face ``GPTNeoXConfig`` so a
    checkpoint's configuration can be mapped onto this dataclass (presumably
    via ``BaseModelArgs``; that helper is defined elsewhere).

    Fields with defaults are optional in the source configuration; the
    defaults reproduce the layout this module has always assumed.
    """

    model_type: str
    max_position_embeddings: int
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    layer_norm_eps: float
    vocab_size: int
    # Base frequency passed to the rotary position embedding.
    rotary_emb_base: int
    # Fraction of each attention head's features that receive rotary embeddings.
    rotary_pct: float
    # Defaults to num_attention_heads when absent; the attention layer in this
    # file gives every head its own key/value projection.
    num_key_value_heads: Optional[int] = None
    # Width of the MLP hidden layer; when absent, the classic 4 * hidden_size.
    intermediate_size: Optional[int] = None
    # True: out = x + attn(ln1(x)) + mlp(ln2(x)) (attention and MLP in parallel).
    # False: h = x + attn(ln1(x)); out = h + mlp(ln2(h)) (sequential residuals).
    use_parallel_residual: bool = True

    def __post_init__(self):
        """Fill in optional fields that are derived from required ones."""
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.hidden_size
class Attention(nn.Module):
    """Multi-head self-attention with (partial) rotary position embeddings.

    Queries, keys and values come from a single fused ``query_key_value``
    projection. Its output is viewed per head as ``[q | k | v]``, each part
    ``head_dim`` wide.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if args.hidden_size % args.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")

        self.hidden_size = args.hidden_size
        self.num_attention_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads

        # Only int(head_dim * rotary_pct) features of each head are rotated;
        # traditional=False selects MLX's non-interleaved rotation layout.
        self.rope = nn.RoPE(
            dims=int(self.head_dim * args.rotary_pct),
            traditional=False,
            base=args.rotary_emb_base,
        )
        self.scale = self.head_dim**-0.5

        self.query_key_value = nn.Linear(
            self.hidden_size, 3 * self.hidden_size, bias=True
        )
        self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Attend over ``x`` (rank 3: batch, sequence, hidden_size).

        Args:
            x: Input activations of shape ``(B, L, hidden_size)``.
            mask: Forwarded unchanged to ``scaled_dot_product_attention``.
            cache: Optional key/value cache. It must expose ``offset`` (used as
                the RoPE position offset) and ``update_and_fetch(keys, values)``,
                which returns the keys/values to attend over.

        Returns:
            Array of shape ``(B, L, hidden_size)``.
        """
        B, L, _ = x.shape

        qkv = self.query_key_value(x)
        # (B, L, 3*hidden) -> (B, L, heads, 3*head_dim), split per head into
        # q/k/v, then move heads in front of the sequence axis.
        qkv = qkv.reshape(B, L, self.num_attention_heads, 3 * self.head_dim)
        queries, keys, values = (
            part.transpose(0, 2, 1, 3) for part in qkv.split(3, axis=-1)
        )

        # New tokens continue after whatever the cache already holds.
        offset = cache.offset if cache is not None else 0
        queries = self.rope(queries, offset=offset)
        keys = self.rope(keys, offset=offset)
        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # (B, heads, L, head_dim) -> (B, L, hidden_size)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.dense(output)
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU (approx.) -> Linear."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        # Honor the checkpoint's MLP width when the args carry it; otherwise
        # fall back to the classic 4x expansion. getattr keeps this working
        # with args objects that do not define the field.
        intermediate_size = (
            getattr(args, "intermediate_size", None) or 4 * self.hidden_size
        )
        self.dense_h_to_4h = nn.Linear(self.hidden_size, intermediate_size)
        self.dense_4h_to_h = nn.Linear(intermediate_size, self.hidden_size)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the feed-forward network to the last axis of ``x``."""
        # gelu_approx corresponds to FastGELUActivation in transformers.
        # NOTE(review): transformers selects the activation from
        # config.hidden_act; checkpoints configured with exact "gelu" may
        # differ slightly numerically from this approximation — confirm.
        return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
class TransformerBlock(nn.Module):
    """One GPT-NeoX decoder layer: attention and MLP, each behind a LayerNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.layer_norm_eps = args.layer_norm_eps
        # Parallel residual is the transformers default; getattr keeps args
        # objects without the field on that path.
        self.use_parallel_residual = getattr(args, "use_parallel_residual", True)
        self.attention = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(
            self.hidden_size,
            eps=self.layer_norm_eps,
        )
        self.post_attention_layernorm = nn.LayerNorm(
            self.hidden_size, eps=self.layer_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run the layer on ``x``; ``mask`` and ``cache`` go to the attention."""
        attn = self.attention(self.input_layernorm(x), mask, cache)
        if self.use_parallel_residual:
            # Attention and MLP both read the same input x, then are summed
            # with the residual.
            ffn = self.mlp(self.post_attention_layernorm(x))
            return attn + ffn + x
        # Sequential layout: the MLP sees the attention-updated stream.
        h = attn + x
        return self.mlp(self.post_attention_layernorm(h)) + h
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, layers, final norm, output head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.layer_norm_eps = args.layer_norm_eps
        # Raise instead of assert so the check survives `python -O`.
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
        self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
        # Separate (untied) projection from hidden states to vocabulary logits.
        self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ) -> mx.array:
        """Compute output logits for token ids ``inputs``.

        Args:
            inputs: Token ids fed to ``embed_in`` (presumably ``(B, L)`` — the
                previous code unpacked two dimensions here).
            cache: Optional list with one cache object per layer, or None.

        Returns:
            The ``embed_out`` projection of the final-normed hidden states.
        """
        hidden_states = self.embed_in(inputs)

        # Built once and shared by every layer; it accounts for any tokens
        # already held in the cache.
        mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, layer_cache in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=layer_cache)

        return self.embed_out(self.final_layer_norm(hidden_states))
class Model(nn.Module):
    """Top-level GPT-NeoX model: wraps GPTNeoXModel and adapts weight names."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPTNeoXModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return the logits computed by the wrapped GPTNeoXModel."""
        return self.model(inputs, cache)

    def sanitize(self, weights):
        """Rename PyTorch/transformers weight keys to this module's layout.

        Entries ending in a register_buffer suffix are dropped. Every key is
        prefixed with ``model.`` (unless already present), ``.gpt_neox.layers.``
        becomes ``.h.``, and any remaining ``.gpt_neox.`` segment is removed.

        Args:
            weights: Mapping from parameter name to value.

        Returns:
            A new dict with the renamed keys; ``weights`` is not modified.
        """
        # Created through register_buffer in PyTorch, not needed here.
        # Built once; str.endswith accepts a tuple of candidate suffixes.
        ignored_suffixes = (
            ".attention.bias",
            ".attention.masked_bias",
            ".attention.rotary_emb.inv_freq",
        )

        new_weights = {}
        for key, value in weights.items():
            if key.endswith(ignored_suffixes):
                continue
            if not key.startswith("model."):
                key = f"model.{key}"
            key = key.replace(".gpt_neox.layers.", ".h.")
            key = key.replace(".gpt_neox.", ".")
            new_weights[key] = value
        return new_weights

    @property
    def layers(self):
        """The list of TransformerBlock layers of the wrapped model."""
        return self.model.h