diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 0c0fc601..b4e7aab8 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -755,6 +755,26 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_gemma3_text(self): + from mlx_lm.models import gemma3_text + + args = gemma3_text.ModelArgs( + model_type="gemma3_text", + hidden_size=128, + num_hidden_layers=12, + intermediate_size=256, + num_attention_heads=4, + head_dim=32, + rms_norm_eps=1e-4, + num_key_value_heads=1, + sliding_window=1024, + sliding_window_pattern=6, + ) + model = gemma3_text.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gpt_bigcode(self): from mlx_lm.models import gpt_bigcode