mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
GPT2 Support (#798)
* GPT-2 model support * Add test for gpt2 model * Fix weight sanitizing for quantization * use approx gelu --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -339,6 +339,22 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt2(self):
|
||||
from mlx_lm.models import gpt2
|
||||
|
||||
args = gpt2.ModelArgs(
|
||||
model_type="gpt2",
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_head=12,
|
||||
n_layer=12,
|
||||
n_positions=1024,
|
||||
layer_norm_epsilon=1e-5,
|
||||
vocab_size=50256,
|
||||
)
|
||||
model = gpt2.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
|
||||
|
||||
def test_openelm(self):
|
||||
from mlx_lm.models import openelm
|
||||
|
||||
|
Reference in New Issue
Block a user