Mirror of https://github.com/ml-explore/mlx-examples.git
Add GPT-neox model (#863)
@@ -22,7 +22,7 @@ class TestLora(unittest.TestCase):
     def tearDown(self):
         sys.stdout = sys.__stdout__
 
-    def test_to_lora(self):
+    def test_llama(self):
         from mlx_lm.models import llama
 
         args = llama.ModelArgs(
@@ -60,6 +60,28 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
 
+    def test_gpt_neox(self):
+        from mlx_lm.models import gpt_neox
+
+        args = gpt_neox.ModelArgs(
+            model_type="gpt_neox",
+            max_position_embeddings=2048,
+            hidden_size=6144,
+            num_attention_heads=64,
+            num_hidden_layers=44,
+            layer_norm_eps=1e-5,
+            vocab_size=50432,
+            rotary_emb_base=10_000,
+            rotary_pct=0.25,
+        )
+
+        num_lora_layers = 4
+        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+
+        model = gpt_neox.Model(args)
+        model.freeze()
+        tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
+
 
 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
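The added test exercises the full conversion path for the new model: build a gpt_neox.Model from ModelArgs, freeze the base weights, and convert the linear projections in the last few layers to LoRA adapters. The sketch below (not part of the commit) mirrors that flow as a standalone script, assuming mlx and mlx-lm are installed and that tuner is importable as `from mlx_lm import tuner`, as the test module relies on; the parameter-key listing at the end is only an illustrative way to inspect the result.

# Minimal standalone sketch of the conversion the test exercises.
# Assumption: `from mlx_lm import tuner` matches the test module's import.
from mlx.utils import tree_flatten
from mlx_lm import tuner
from mlx_lm.models import gpt_neox

# Same GPT-NeoX configuration the test uses.
args = gpt_neox.ModelArgs(
    model_type="gpt_neox",
    max_position_embeddings=2048,
    hidden_size=6144,
    num_attention_heads=64,
    num_hidden_layers=44,
    layer_norm_eps=1e-5,
    vocab_size=50432,
    rotary_emb_base=10_000,
    rotary_pct=0.25,
)

model = gpt_neox.Model(args)
model.freeze()  # freeze base weights so only LoRA adapters remain trainable

# Convert the linear layers in the last 4 transformer blocks to LoRA layers.
tuner.utils.linear_to_lora_layers(
    model, 4, {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
)

# Illustrative check: the trainable parameter names should now be only the
# LoRA adapter weights added by the conversion.
trainable_keys = [k for k, _ in tree_flatten(model.trainable_parameters())]
print(len(trainable_keys), trainable_keys[:4])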