mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Adds EXAONE architecture. (#1145)
* Adds EXAONE architecture. * nits + format * format * clean up and fix rope * clean up and fix rope --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models import rope_utils
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||
|
||||
|
||||
@@ -126,6 +128,26 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(cache.offset, 22)
|
||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||
|
||||
def test_rope(self):
|
||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||
|
||||
rope = rope_utils.initialize_rope(
|
||||
32,
|
||||
base=100,
|
||||
traditional=False,
|
||||
scaling_config={"rope_type": "linear", "factor": 10.0},
|
||||
)
|
||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||
|
||||
rope = rope_utils.initialize_rope(
|
||||
32,
|
||||
base=100,
|
||||
traditional=False,
|
||||
scaling_config={"rope_type": "llama3", "factor": 2.0},
|
||||
)
|
||||
self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
|
||||
|
||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||
|
||||
self.assertEqual(len(model.layers), num_layers)
|
||||
@@ -812,6 +834,23 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_exaone(self):
|
||||
from mlx_lm.models import exaone
|
||||
|
||||
args = exaone.ModelArgs(
|
||||
model_type="exaone",
|
||||
hidden_size=128,
|
||||
num_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=1000,
|
||||
layer_norm_epsilon=1e-4,
|
||||
rope_theta=10000,
|
||||
)
|
||||
model = exaone.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user