mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add GPT-neox model (#863)
This commit is contained in:
@@ -22,7 +22,7 @@ class TestLora(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
def test_to_lora(self):
|
||||
def test_llama(self):
|
||||
from mlx_lm.models import llama
|
||||
|
||||
args = llama.ModelArgs(
|
||||
@@ -60,6 +60,28 @@ class TestLora(unittest.TestCase):
|
||||
params["keys"] = ["self_attn.k_proj"]
|
||||
check_config(params)
|
||||
|
||||
def test_gpt_neox(self):
    """Smoke test: LoRA adapters can be attached to a frozen GPT-NeoX model."""
    from mlx_lm.models import gpt_neox

    # GPT-NeoX-20B-style configuration (sizes match the published model card).
    model_args = gpt_neox.ModelArgs(
        model_type="gpt_neox",
        vocab_size=50432,
        hidden_size=6144,
        num_hidden_layers=44,
        num_attention_heads=64,
        max_position_embeddings=2048,
        layer_norm_eps=1e-5,
        rotary_emb_base=10_000,
        rotary_pct=0.25,
    )

    lora_params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}

    neox = gpt_neox.Model(model_args)
    neox.freeze()
    # Replace linear projections in 4 of the model's layers with LoRA layers;
    # the test passes if this conversion runs without raising.
    tuner.utils.linear_to_lora_layers(neox, 4, lora_params)
|
||||
|
||||
|
||||
class TestScheduleConfig(unittest.TestCase):
|
||||
def test_join(self):
|
||||
|
@@ -355,6 +355,25 @@ class TestModels(unittest.TestCase):
|
||||
model = gpt2.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
|
||||
|
||||
def test_gpt_neox(self):
    """Build a GPT-NeoX model and run it through the shared model checks."""
    from mlx_lm.models import gpt_neox

    # GPT-NeoX-20B-style configuration (sizes match the published model card).
    config = gpt_neox.ModelArgs(
        model_type="gpt_neox",
        vocab_size=50432,
        hidden_size=6144,
        num_hidden_layers=44,
        num_attention_heads=64,
        max_position_embeddings=2048,
        layer_norm_eps=1e-5,
        rotary_emb_base=10_000,
        rotary_pct=0.25,
    )
    self.model_test_runner(
        gpt_neox.Model(config),
        config.model_type,
        config.vocab_size,
        config.num_hidden_layers,
    )
|
||||
|
||||
def test_openelm(self):
|
||||
from mlx_lm.models import openelm
|
||||
|
||||
|
Reference in New Issue
Block a user