mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Add support for Cohere's Command-R (#565)
* initial commit for command-R
* update mlp, layernorm, lm_head and model args
* add custom layernorm
* add default to tie_word_embeddings
* add layernorm weight type and refactor
* update layernorm (bias conditional) in model/layers
* fix layer norm use traditional rope
* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -254,7 +254,6 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
|
||||
|
||||
def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
|
||||
|
||||
from mlx_lm.models import starcoder2
|
||||
|
||||
args = starcoder2.ModelArgs(
|
||||
@@ -276,6 +275,17 @@ class TestModels(unittest.TestCase):
|
||||
self.assertIn("lm_head.weight", sanitized_weights)
|
||||
self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
|
||||
|
||||
def test_cohere(self):
    """Construct a Cohere model and pass it to ``self.model_test_runner``.

    Only ``model_type`` is passed explicitly to ``cohere.ModelArgs``; the
    vocabulary size and layer count handed to the runner are read back
    from the resulting args object.
    """
    from mlx_lm.models import cohere

    cohere_args = cohere.ModelArgs(model_type="cohere")
    cohere_model = cohere.Model(cohere_args)

    self.model_test_runner(
        cohere_model,
        cohere_args.model_type,
        cohere_args.vocab_size,
        cohere_args.num_hidden_layers,
    )
|
||||
|
||||
|
||||
# Run the unittest command-line runner only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Reference in New Issue
Block a user