add tests

This commit is contained in:
Prince Canuma 2024-12-14 16:39:46 +01:00
parent 52595dafae
commit 2f443cc6d7

View File

@ -851,6 +851,19 @@ class TestModels(unittest.TestCase):
model = exaone.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)
def test_cohere2(self):
from mlx_lm.models import cohere2
args = cohere2.ModelArgs(
model_type="cohere2",
hidden_size=4096,
head_dim=128,
num_hidden_layers=40,
sliding_window=4096,
sliding_window_pattern=4,
)
model = cohere2.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_hidden_layers)
if __name__ == "__main__":
unittest.main()