mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
Add support for Gemma3 (#1336)
* add support for gemma3 * fix model loading * revert rmsnorm * revert is sliding pattern * revert * add tests * formatting * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * fix sliding window mask --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -755,6 +755,26 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma3_text(self):
|
||||
from mlx_lm.models import gemma3_text
|
||||
|
||||
args = gemma3_text.ModelArgs(
|
||||
model_type="gemma3_text",
|
||||
hidden_size=128,
|
||||
num_hidden_layers=12,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
head_dim=32,
|
||||
rms_norm_eps=1e-4,
|
||||
num_key_value_heads=1,
|
||||
sliding_window=1024,
|
||||
sliding_window_pattern=6,
|
||||
)
|
||||
model = gemma3_text.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt_bigcode(self):
|
||||
from mlx_lm.models import gpt_bigcode
|
||||
|
||||
|
Reference in New Issue
Block a user