GPT2 Support (#798)

* GPT-2 model support

* Add test for gpt2 model

* Fix weight sanitizing for quantization

* use approx gelu

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Derek Lewis
2024-06-02 16:33:20 -07:00
committed by GitHub
parent c457a3f88b
commit 89b0b75250
3 changed files with 225 additions and 0 deletions

View File

@@ -339,6 +339,22 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt2(self):
from mlx_lm.models import gpt2
args = gpt2.ModelArgs(
model_type="gpt2",
n_ctx=1024,
n_embd=768,
n_head=12,
n_layer=12,
n_positions=1024,
layer_norm_epsilon=1e-5,
vocab_size=50256,
)
model = gpt2.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_openelm(self):
from mlx_lm.models import openelm