mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add support for qwen2moe (#640)
* Add sparse-MoE block and update decoder logic
* Update file name to match HF
* Update name
* Code formatting
* Update gates calculation
* Add support for Qwen2MoE
* Fix pytest
* Code formatting and fix missing comma in utils
* Remove decoder sparse step
* Remove gate layer anti-quantisation
* Remove unused argument

Co-authored-by: bozheng-hit <dsoul0621@gmail.com>
This commit is contained in:
@@ -114,6 +114,27 @@ class TestModels(unittest.TestCase):
|
||||
args.n_layers,
|
||||
)
|
||||
|
||||
def test_qwen2_moe(self):
|
||||
from mlx_lm.models import qwen2_moe
|
||||
|
||||
args = qwen2_moe.ModelArgs(
|
||||
model_type="qwen2_moe",
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=2048,
|
||||
num_attention_heads=4,
|
||||
rms_norm_eps=1e-5,
|
||||
vocab_size=10_000,
|
||||
num_experts_per_tok=4,
|
||||
num_experts=16,
|
||||
moe_intermediate_size=1024,
|
||||
shared_expert_intermediate_size=2048,
|
||||
)
|
||||
model = qwen2_moe.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_qwen2(self):
|
||||
from mlx_lm.models import qwen2
|
||||
|
||||
|
Reference in New Issue
Block a user