GPT2 Support (#798)

* GPT-2 model support

* Add test for gpt2 model

* Fix weight sanitizing for quantization

* use approx gelu

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Derek Lewis, 2024-06-02 16:33:20 -07:00, committed by GitHub
parent c457a3f88b
commit 89b0b75250
3 changed files with 225 additions and 0 deletions
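
A note on the "use approx gelu" bullet: GPT-2's original activation is the tanh-based GELU approximation (Hugging Face's `gelu_new`), and mlx's `nn.gelu_approx` implements a close approximation of the same curve. A minimal reference sketch of the classic formula (the name `gelu_tanh` is illustrative, not part of the diff):

import math
import mlx.core as mx

def gelu_tanh(x: mx.array) -> mx.array:
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))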

llms/mlx_lm/models/gpt2.py (new file, +207)

@@ -0,0 +1,207 @@
from dataclasses import dataclass
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None

    def __post_init__(self):
        # GPT-2 uses standard multi-head attention, so the number of
        # key/value heads defaults to the number of query heads.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim**-0.5

        # GPT-2 fuses the query, key, and value projections into one layer.
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)
class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        # Approximate GELU, matching GPT-2's original activation.
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon

        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ) -> mx.array:
        # Pre-norm residual blocks: attention, then the MLP.
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        return h + r
class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0

        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        # Learned position embeddings, offset by the cache position so that
        # tokens generated after the prompt are positioned correctly.
        offset = cache[0].offset if cache is not None else 0
        position_ids = mx.arange(offset, offset + L)
        hidden_states += self.wpe(position_ids)

        mask = None
        if L > 1:
            mask = create_additive_causal_mask(L, offset)
            mask = mask.astype(hidden_states.dtype)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        # GPT-2 ties the input and output embeddings, so the token
        # embedding matrix doubles as the language-model head.
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            # The checkpoint's `attn.bias` is a precomputed causal-mask
            # buffer, not a learned parameter; drop it.
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            # Hugging Face GPT-2 checkpoints store these projections as
            # Conv1D, whose weight layout is the transpose of nn.Linear's;
            # transpose them so quantization sees the expected shapes.
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)

        # Prefix all keys with "model." to match this module's structure.
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h

    @property
    def head_dim(self):
        return self.args.n_embd // self.args.n_head

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
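
For context, a minimal usage sketch (not part of the diff). It assumes the standard `mlx_lm` entry points `load` and `generate`; the Hugging Face repo id "gpt2" is an example checkpoint:

from mlx_lm import load, generate

# Download/convert the checkpoint and run a short completion.
model, tokenizer = load("gpt2")
text = generate(model, tokenizer, prompt="In a shocking finding,", max_tokens=50)
print(text)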

llms/mlx_lm/tuner/utils.py (+2)

@@ -108,6 +108,8 @@ def linear_to_lora_layers(
    elif model.model_type == "gpt_bigcode":
        keys = set(["attn.c_attn"])
    elif model.model_type == "gpt2":
        keys = set(["attn.c_attn"])
    elif model.model_type == "olmo":
        keys = set(["att_proj"])
    elif model.model_type == "openelm":
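
This hunk registers `attn.c_attn`, GPT-2's fused QKV projection, as the LoRA target. As a rough illustration of what wrapping that layer means, here is a hypothetical, self-contained low-rank adapter (mlx_lm's tuner has its own LoRALinear; the names and defaults below are illustrative only):

import math
import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, linear: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = linear  # base projection (frozen during fine-tuning)
        out_dims, in_dims = linear.weight.shape
        self.scale = alpha / r
        # Low-rank factors: a is small-random, b starts at zero so the
        # wrapped layer initially matches the base layer exactly.
        self.lora_a = mx.random.normal((in_dims, r)) * (1 / math.sqrt(in_dims))
        self.lora_b = mx.zeros((r, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)

# e.g. model.model.h[0].attn.c_attn = LoRALinear(model.model.h[0].attn.c_attn)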

llms/tests/test_models.py (+16)

@@ -339,6 +339,22 @@ class TestModels(unittest.TestCase):
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt2(self):
        from mlx_lm.models import gpt2

        args = gpt2.ModelArgs(
            model_type="gpt2",
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            n_positions=1024,
            layer_norm_epsilon=1e-5,
            vocab_size=50256,
        )
        model = gpt2.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_openelm(self):
        from mlx_lm.models import openelm