"""Tests for DeciLM implementation."""
|
||
|
|
||
|
import unittest
|
||
|
import mlx.core as mx
|
||
|
import mlx.nn as nn
|
||
|
|
||
|
import sys
|
||
|
sys.path.append('..')
|
||
|
|
||
|
from decilm import DeciLMArgs, DummyAttention, DummyFFN, VariableAttention, DeciLMBlock
|
||
|
|
||
|
|
||
|
class TestDeciLMComponents(unittest.TestCase):
|
||
|
"""Test DeciLM model components."""

    def setUp(self):
        """Set up test fixtures."""
        self.args = DeciLMArgs(
            hidden_size=4096,
            num_attention_heads=32,
            num_key_value_heads=8,
            intermediate_size=11008,
        )
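        # These dimensions match a Llama-7B-scale configuration: 4096 hidden size,
        # 32 attention heads with 8 KV heads (grouped-query attention, 4 query
        # heads per KV head), and an 11008-wide FFN.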

    def test_dummy_attention(self):
        """Test dummy attention passthrough."""
        dummy_attn = DummyAttention(self.args)

        # Test input
        x = mx.random.normal((2, 10, 4096))

        # Should return input unchanged
        output = dummy_attn(x)
        self.assertTrue(mx.array_equal(output, x))

    def test_dummy_ffn(self):
        """Test dummy FFN passthrough."""
        dummy_ffn = DummyFFN(self.args)

        # Test input
        x = mx.random.normal((2, 10, 4096))

        # Should return input unchanged
        output = dummy_ffn(x)
        self.assertTrue(mx.array_equal(output, x))

    def test_variable_attention(self):
        """Test variable attention with different KV heads."""
        # Test with 4 KV heads (less than Q heads)
        var_attn = VariableAttention(self.args, n_kv_heads=4)
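        # With 32 query heads and n_kv_heads=4, each KV head presumably serves a
        # group of 8 query heads (grouped-query attention).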

        x = mx.random.normal((2, 10, 4096))
        output = var_attn(x)

        # Output shape should match input
        self.assertEqual(output.shape, (2, 10, 4096))

    def test_decilm_block_dummy(self):
        """Test DeciLM block with dummy components."""
        # Config with dummy attention and FFN
        block_config = {
            "attention": {"no_op": True},
            "ffn": {"no_op": True}
        }

        block = DeciLMBlock(self.args, block_config)
        x = mx.random.normal((2, 10, 4096))

        output = block(x)

        # With both sublayers dummy, the output should stay close to the input
        # (only the layer norms are applied)
        self.assertEqual(output.shape, x.shape)

    def test_decilm_block_mixed(self):
        """Test DeciLM block with mixed dummy/active components."""
        # Config with active attention but dummy FFN
        block_config = {
            "attention": {"no_op": False, "n_heads_in_group": 8},
            "ffn": {"no_op": True}
        }
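        # "n_heads_in_group" is assumed to set the query-head group size per KV
        # head, so 8 here would correspond to 32 // 8 = 4 KV heads.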

        block = DeciLMBlock(self.args, block_config)
        x = mx.random.normal((2, 10, 4096))

        output = block(x)
        self.assertEqual(output.shape, x.shape)

    def test_block_config_variations(self):
        """Test various block configurations."""
        configs = [
            # Standard block
            {
                "attention": {"no_op": False, "n_heads_in_group": 8},
                "ffn": {"no_op": False, "ffn_mult": 2.5}
            },
            # Variable FFN multiplier
            {
                "attention": {"no_op": False, "n_heads_in_group": 8},
                "ffn": {"no_op": False, "ffn_mult": 1.5}
            },
            # Different KV heads
            {
                "attention": {"no_op": False, "n_heads_in_group": 4},
                "ffn": {"no_op": False, "ffn_mult": 2.5}
            },
        ]
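        # "ffn_mult" is assumed to scale the FFN intermediate width relative to
        # hidden_size (a per-block width multiplier in DeciLM's variable-block design).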

        x = mx.random.normal((1, 5, 4096))

        for config in configs:
            block = DeciLMBlock(self.args, config)
            output = block(x)
            self.assertEqual(output.shape, x.shape)


if __name__ == "__main__":
    unittest.main()