Add support for Gemma3 (#1336)

* add support for gemma3

* fix model loading

* revert rmsnorm

* revert is sliding pattern

* revert

* add tests

* formatting

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix sliding window mask

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Prince Canuma
2025-03-13 16:14:25 +01:00
committed by GitHub
parent 3e5baf583b
commit 2fce02acd8
2 changed files with 258 additions and 0 deletions

View File

@@ -755,6 +755,26 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma3_text(self):
from mlx_lm.models import gemma3_text
args = gemma3_text.ModelArgs(
model_type="gemma3_text",
hidden_size=128,
num_hidden_layers=12,
intermediate_size=256,
num_attention_heads=4,
head_dim=32,
rms_norm_eps=1e-4,
num_key_value_heads=1,
sliding_window=1024,
sliding_window_pattern=6,
)
model = gemma3_text.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt_bigcode(self):
from mlx_lm.models import gpt_bigcode