diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
new file mode 100644
index 00000000..9549f322
--- /dev/null
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -0,0 +1,227 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .base import BaseModelArgs, create_additive_causal_mask
+
+# Based on the transformers implementation at:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    max_position_embeddings: int
+    hidden_size: int
+    num_attention_heads: int
+    num_hidden_layers: int
+    layer_norm_eps: float
+    vocab_size: int
+    rotary_emb_base: int
+    rotary_pct: float
+    num_key_value_heads: int = None
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        assert (
+            args.hidden_size % args.num_attention_heads == 0
+        ), "hidden_size must be divisible by num_attention_heads"
+
+        self.hidden_size = args.hidden_size
+        self.num_attention_heads = args.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+        self.rope = nn.RoPE(
+            dims=int(self.head_dim * args.rotary_pct),
+            traditional=False,
+            base=args.rotary_emb_base,
+        )
+
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = nn.Linear(
+            self.hidden_size, 3 * self.hidden_size, bias=True
+        )
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        qkv = self.query_key_value(x)
+
+        new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
+        qkv = qkv.reshape(*new_qkv_shape)
+
+        queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.dense(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.hidden_size = args.hidden_size
+        self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
+        self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
+
+    def __call__(self, x) -> mx.array:
+        # gelu_approx corresponds to FastGELUActivation in transformers.
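+        # nn.gelu_approx approximates GELU; the exact erf-based form is nn.gelu.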
+        return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.hidden_size = args.hidden_size
+        self.layer_norm_eps = args.layer_norm_eps
+        self.attention = Attention(args)
+        self.mlp = MLP(args)
+        self.input_layernorm = nn.LayerNorm(
+            self.hidden_size,
+            eps=self.layer_norm_eps,
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
+            self.hidden_size, eps=self.layer_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        residual = x
+        # NeoX runs the attention and feedforward networks in parallel.
+        attn = self.attention(self.input_layernorm(x), mask, cache)
+        ffn = self.mlp(self.post_attention_layernorm(x))
+        out = attn + ffn + residual
+        return out
+
+
+class GPTNeoXModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.hidden_size = args.hidden_size
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        self.layer_norm_eps = args.layer_norm_eps
+        assert self.vocab_size > 0
+        self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
+        self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        _, L = inputs.shape
+
+        hidden_states = self.embed_in(inputs)
+
+        mask = None
+        if hidden_states.shape[1] > 1:
+            mask = create_additive_causal_mask(
+                hidden_states.shape[1], cache[0].offset if cache is not None else 0
+            )
+            mask = mask.astype(hidden_states.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.h)
+
+        for layer, c in zip(self.h, cache):
+            hidden_states = layer(hidden_states, mask, cache=c)
+
+        out = self.final_layer_norm(hidden_states)
+        out = self.embed_out(out)
+
+        return out
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = GPTNeoXModel(args)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out = self.model(inputs, cache)
+        return out
+
+    def sanitize(self, weights):
+        new_weights = {}
+
+        for w_key, w_value in weights.items():
+            # Created through register_buffer in PyTorch, not needed here.
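+            # These buffers (causal mask constants and rotary inverse frequencies)
+            # are recreated on the fly here, so they can be dropped when loading.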
+            ignore_suffixes = [
+                ".attention.bias",
+                ".attention.masked_bias",
+                ".attention.rotary_emb.inv_freq",
+            ]
+
+            skip_weight = False
+            for ignored_suffix in ignore_suffixes:
+                if w_key.endswith(ignored_suffix):
+                    skip_weight = True
+                    break
+
+            if skip_weight:
+                continue
+
+            if not w_key.startswith("model."):
+                w_key = f"model.{w_key}"
+
+            w_key = w_key.replace(".gpt_neox.layers.", ".h.")
+            w_key = w_key.replace(".gpt_neox.", ".")
+
+            new_weights[w_key] = w_value
+
+        return new_weights
+
+    @property
+    def layers(self):
+        return self.model.h
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index fe9740f5..2c97228d 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -111,6 +111,8 @@ def linear_to_lora_layers(
         keys = set(["attn.c_attn"])
     elif model.model_type == "gpt2":
         keys = set(["attn.c_attn"])
+    elif model.model_type == "gpt_neox":
+        keys = set(["attention.query_key_value"])
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
     elif model.model_type == "openelm":
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index 1cddaf3a..f37ae3c2 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -22,7 +22,7 @@ class TestLora(unittest.TestCase):
     def tearDown(self):
         sys.stdout = sys.__stdout__

-    def test_to_lora(self):
+    def test_llama(self):
         from mlx_lm.models import llama

         args = llama.ModelArgs(
@@ -60,6 +60,28 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)

+    def test_gpt_neox(self):
+        from mlx_lm.models import gpt_neox
+
+        args = gpt_neox.ModelArgs(
+            model_type="gpt_neox",
+            max_position_embeddings=2048,
+            hidden_size=6144,
+            num_attention_heads=64,
+            num_hidden_layers=44,
+            layer_norm_eps=1e-5,
+            vocab_size=50432,
+            rotary_emb_base=10_000,
+            rotary_pct=0.25,
+        )
+
+        num_lora_layers = 4
+        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+
+        model = gpt_neox.Model(args)
+        model.freeze()
+        tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
+

 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index a3c0b9c4..5149a0f1 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -355,6 +355,25 @@ class TestModels(unittest.TestCase):
         model = gpt2.Model(args)
         self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

+    def test_gpt_neox(self):
+        from mlx_lm.models import gpt_neox
+
+        args = gpt_neox.ModelArgs(
+            model_type="gpt_neox",
+            max_position_embeddings=2048,
+            hidden_size=6144,
+            num_attention_heads=64,
+            num_hidden_layers=44,
+            layer_norm_eps=1e-5,
+            vocab_size=50432,
+            rotary_emb_base=10_000,
+            rotary_pct=0.25,
+        )
+        model = gpt_neox.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_openelm(self):
         from mlx_lm.models import openelm