From 7cdd1b69ac0a26057ff45d0da1496fe7c6bbfb34 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 7 Mar 2024 09:31:57 -0800
Subject: [PATCH] Enable unit testing in Circle and start some MLX LM tests
 (#545)

* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
---
 .circleci/config.yml             |  25 ++++
 llms/CONTRIBUTING.md             |   9 ++
 llms/mlx_lm/models/gemma.py      |   2 +-
 llms/mlx_lm/models/mixtral.py    |   2 --
 llms/mlx_lm/models/olmo.py       |   1 -
 llms/mlx_lm/models/phi.py        |   2 +-
 llms/mlx_lm/models/plamo.py      |  17 ++-
 llms/mlx_lm/models/qwen.py       |   7 +-
 llms/mlx_lm/models/stablelm.py   |   1 -
 llms/mlx_lm/models/starcoder2.py |   4 +-
 llms/tests/test_models.py        | 199 +++++++++++++++++++++++++++++++
 llms/tests/test_utils.py         |  45 +++++++
 12 files changed, 294 insertions(+), 20 deletions(-)
 create mode 100644 llms/tests/test_models.py
 create mode 100644 llms/tests/test_utils.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index aec28e77..2ecda2dc 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -17,6 +17,30 @@ jobs:
             pre-commit run --all
             if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
 
+  mlx_lm_build_and_test:
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.large.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install dependencies
+          command: |
+            brew install python@3.8
+            python3.8 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install unittest-xml-reporting
+            cd llms/
+            pip install -e .
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            python -m xmlrunner discover -v llms/tests -o test-results/
+      - store_test_results:
+          path: test-results
+
 workflows:
   build_and_test:
     when:
@@ -24,6 +48,7 @@
             pattern: "^(?!pull/)[-\\w]+$"
             value: << pipeline.git.branch >>
     jobs:
+      - mlx_lm_build_and_test
       - linux_build_and_test
 
   prb:
diff --git a/llms/CONTRIBUTING.md b/llms/CONTRIBUTING.md
index e3590f4e..d85067cc 100644
--- a/llms/CONTRIBUTING.md
+++ b/llms/CONTRIBUTING.md
@@ -36,3 +36,12 @@ To determine the model layer names, we suggest either:
 
 To add LoRA support edit
 [`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
+
+Finally, add a test for the new model type to the [model
+tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
+
+From the `llms/` directory, you can run the tests with:
+
+```shell
+python -m unittest discover tests/
+```
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 2bc782b7..2a99c3c9 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
     head_dim: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: int
     rope_theta: float = 10000
     rope_traditional: bool = False
 
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index c2ddcb7c..027fcd7c 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -13,7 +13,6 @@ from .layers import RMSNorm
 class ModelArgs(BaseModelArgs):
     model_type: str
     vocab_size: int = 32000
-    max_position_embeddings: int = 4096 * 32
     hidden_size: int = 4096
     intermediate_size: int = 14336
     num_hidden_layers: int = 32
@@ -38,7 +37,6 @@ class MixtralAttention(nn.Module):
         self.num_heads = args.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = args.num_key_value_heads
-        self.max_position_embeddings = args.max_position_embeddings
         self.rope_theta = args.rope_theta
 
         self.repeats = self.num_heads // self.num_key_value_heads
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index f97ce6f9..548457d6 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -24,7 +24,6 @@ class ModelArgs(BaseModelArgs):
     n_heads: int
     vocab_size: int
    embedding_size: int
-    model_type: str
     rope_theta: float = 10000
     rope_traditional: bool = False
     mlp_ratio: int = 4
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 85d16759..d8ef54f4 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -11,7 +11,7 @@ from .layers import LayerNorm
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str
+    model_type: str = "phi"
     max_position_embeddings: int = 2048
     vocab_size: int = 51200
     hidden_size: int = 2560
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index ba026335..c0a32648 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    n_shared_head: int = (8,)
+    n_shared_head: int = 8
     rope_theta: float = 10000
     rope_traditional: bool = False
 
@@ -80,16 +80,11 @@
             bsz, q_len, self.v_num_heads, self.v_dim
         ).transpose(0, 2, 1, 3)
 
-        def _expand_kv(a: mx.array) -> mx.array:
-            a = mx.concatenate(
-                [mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
-            )
-            return a.reshape([bsz, self.q_num_heads, q_len, -1])
-
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        key_states = _expand_kv(key_states)
-        value_states = _expand_kv(value_states)
+        repeats = self.config.n_shared_head
+        key_states = mx.repeat(key_states, repeats, axis=1)
+        value_states = mx.repeat(value_states, repeats, axis=1)
 
         kv_seq_len = 0
         if cache is not None:
@@ -222,3 +217,7 @@
     ) -> Tuple[mx.array, mx.array]:
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.model.layers.layers
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 16609414..5fe02e98 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
         for e, layer in enumerate(self.h):
             x, cache[e] = layer(x, mask, cache[e])
 
-        x = self.ln_f(x[:, T - 1 : T, :])
-        return x, cache
+        return self.ln_f(x), cache
 
 
 class Model(nn.Module):
@@ -162,3 +161,7 @@
     ) -> Tuple[mx.array, mx.array]:
         y, cache = self.transformer(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.transformer.h
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 5fbca3ae..f03051bc 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -11,7 +11,6 @@ from .layers import LayerNorm
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    max_position_embeddings: int
     model_type: str
     vocab_size: int
     hidden_size: int
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 7a431800..582bdbcb 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -15,10 +15,8 @@ class ModelArgs(BaseModelArgs):
     num_hidden_layers: int
     intermediate_size: int
     num_attention_heads: int
-    num_key_value_heads: int = None
-    max_position_embeddings: int = 16384
+    num_key_value_heads: int
     norm_epsilon: float = 1e-5
-    norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
new file mode 100644
index 00000000..7dcc82b2
--- /dev/null
+++ b/llms/tests/test_models.py
@@ -0,0 +1,199 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+from mlx.utils import tree_map
+
+
+class TestModels(unittest.TestCase):
+
+    def model_test_runner(self, model, model_type, vocab_size, num_layers):
+
+        self.assertEqual(len(model.layers), num_layers)
+        self.assertEqual(model.model_type, model_type)
+
+        for t in [mx.float32, mx.float16]:
+            model.update(tree_map(lambda p: p.astype(t), model.parameters()))
+
+            inputs = mx.array([[0, 1]])
+            outputs, cache = model(inputs)
+            self.assertEqual(outputs.shape, (1, 2, vocab_size))
+            self.assertEqual(outputs.dtype, t)
+
+            outputs, cache = model(mx.argmax(outputs[1, :], keepdims=True), cache=cache)
+            self.assertEqual(outputs.shape, (1, 1, vocab_size))
+            self.assertEqual(outputs.dtype, t)
+
+    def test_llama(self):
+        from mlx_lm.models import llama
+
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = llama.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_phi2(self):
+        from mlx_lm.models import phi
+
+        args = phi.ModelArgs()
+        model = phi.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_gemma(self):
+        from mlx_lm.models import gemma
+
+        args = gemma.ModelArgs(
+            model_type="gemma",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            head_dim=128,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            num_key_value_heads=4,
+        )
+        model = gemma.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_mixtral(self):
+        from mlx_lm.models import mixtral
+
+        # Make a baby mixtral, because it will actually do the
+        # eval
+        args = mixtral.ModelArgs(
+            model_type="mixtral",
+            vocab_size=100,
+            hidden_size=32,
+            intermediate_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_experts_per_tok=2,
+            num_key_value_heads=2,
+            num_local_experts=4,
+        )
+        model = mixtral.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    @unittest.skip("requires ai2-olmo")
+    def test_olmo(self):
+        from mlx_lm.models import olmo
+
+        args = olmo.ModelArgs(
+            model_type="olmo",
+            d_model=1024,
+            n_layers=4,
+            mlp_hidden_size=2048,
+            n_heads=2,
+            vocab_size=10_000,
+            embedding_size=10_000,
+        )
+        model = olmo.Model(args)
+        self.model_test_runner(
+            model,
+            args.model_type,
+            args.vocab_size,
+            args.n_layers,
+        )
+
+    def test_qwen2(self):
+        from mlx_lm.models import qwen2
+
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = qwen2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_qwen(self):
+        from mlx_lm.models import qwen
+
+        args = qwen.ModelArgs(
+            model_type="qwen",
+        )
+        model = qwen.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_plamo(self):
+        from mlx_lm.models import plamo
+
+        args = plamo.ModelArgs(
+            model_type="plamo",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_stablelm(self):
+        from mlx_lm.models import stablelm
+
+        args = stablelm.ModelArgs(
+            model_type="stablelm",
+            vocab_size=10_000,
+            hidden_size=1024,
+            num_attention_heads=4,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            partial_rotary_factor=1.0,
+            intermediate_size=2048,
+            layer_norm_eps=1e-2,
+            rope_theta=10_000,
+            use_qkv_bias=False,
+        )
+        model = stablelm.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_starcoder2(self):
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+        )
+        model = starcoder2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/llms/tests/test_utils.py b/llms/tests/test_utils.py
new file mode 100644
index 00000000..f68bb13c
--- /dev/null
+++ b/llms/tests/test_utils.py
@@ -0,0 +1,45 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+from mlx.utils import tree_flatten
+from mlx_lm import utils
+
+HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+
+class TestUtils(unittest.TestCase):
+
+    def test_load(self):
+        model, _ = utils.load(HF_MODEL_PATH)
+
+        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)
+
+        mx.eval(model_lazy.parameters())
+
+        p1 = model.layers[0].mlp.up_proj.weight
+        p2 = model_lazy.layers[0].mlp.up_proj.weight
+        self.assertTrue(mx.allclose(p1, p2))
+
+    def test_make_shards(self):
+        from mlx_lm.models import llama
+
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=2048,
+            num_hidden_layers=32,
+            intermediate_size=4096,
+            num_attention_heads=32,
+            rms_norm_eps=1e-5,
+            vocab_size=30_000,
+        )
+        model = llama.Model(args)
+        weights = tree_flatten(model.parameters())
+        gb = sum(p.nbytes for _, p in weights) // 2**30
+        shards = utils.make_shards(dict(weights), 1)
+        self.assertTrue(gb <= len(shards) <= gb + 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
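
As a companion to the CONTRIBUTING.md note above ("add a test for the new model type"), here is a minimal sketch of what such a test could look like inside the `TestModels` class of `llms/tests/test_models.py`. The module name `newmodel` and the exact `ModelArgs` fields are hypothetical placeholders, not part of the patch; a real test must import the actual model module and pass its real configuration fields.

```python
# Hypothetical sketch only: `newmodel` and its ModelArgs fields are placeholders.
# A real test mirrors the existing cases (e.g. test_llama) and reuses
# model_test_runner, which checks the layer count, model_type, output shapes,
# and dtypes in both float32 and float16.
def test_newmodel(self):
    from mlx_lm.models import newmodel  # placeholder module name

    args = newmodel.ModelArgs(
        model_type="newmodel",
        hidden_size=1024,
        num_hidden_layers=4,
        intermediate_size=2048,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
    )
    model = newmodel.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
```

From the `llms/` directory, the new case can then be run on its own with, for example, `python -m unittest discover tests/ -p test_models.py -v`.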