Enable unit testing in Circle and start some MLX LM tests (#545)

* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
Awni Hannun 2024-03-07 09:31:57 -08:00 committed by GitHub
parent ef32379bc6
commit 7cdd1b69ac
12 changed files with 294 additions and 20 deletions


@@ -17,6 +17,30 @@ jobs:
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
mlx_lm_build_and_test:
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.8
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e .
- run:
name: Run Python tests
command: |
source env/bin/activate
python -m xmlrunner discover -v llms/tests -o test-results/
- store_test_results:
path: test-results
workflows:
build_and_test:
when:
@@ -24,6 +48,7 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- mlx_lm_build_and_test
- linux_build_and_test
prb:
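The new `mlx_lm_build_and_test` job can be approximated locally with the same steps (a sketch, assuming Python 3.8 is available and you run it from the repository root):

```shell
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip unittest-xml-reporting
(cd llms/ && pip install -e .)
python -m xmlrunner discover -v llms/tests -o test-results/
```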


@@ -36,3 +36,12 @@ To determine the model layer names, we suggest either:
To add LoRA support edit
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60).
Finally, add a test for the new model type to the [model
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
From the `llms/` directory, you can run the tests with:
```shell
python -m unittest discover tests/
```
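For reference, a test for a hypothetical `mymodel` type (all names below are illustrative) would follow the same pattern as the existing cases in `test_models.py`:

```python
def test_mymodel(self):
    # Hypothetical model module; the args mirror the other small test configs.
    from mlx_lm.models import mymodel

    args = mymodel.ModelArgs(
        model_type="mymodel",
        hidden_size=1024,
        num_hidden_layers=4,
        intermediate_size=2048,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
    )
    model = mymodel.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
```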


@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
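The old `num_key_value_heads: int = None` default contradicted its own annotation; making the field required surfaces an incomplete config at construction time instead. A minimal sketch of the dataclass behavior (illustrative class name):

```python
from dataclasses import dataclass

@dataclass
class Args:
    num_key_value_heads: int  # required: no misleading None default

try:
    Args()  # omitting the field now fails loudly
except TypeError as e:
    print(e)  # missing 1 required positional argument

print(Args(num_key_value_heads=4))  # Args(num_key_value_heads=4)
```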


@@ -13,7 +13,6 @@ from .layers import RMSNorm
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int = 32000
max_position_embeddings: int = 4096 * 32
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
@@ -38,7 +37,6 @@ class MixtralAttention(nn.Module):
self.num_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = args.num_key_value_heads
self.max_position_embeddings = args.max_position_embeddings
self.rope_theta = args.rope_theta
self.repeats = self.num_heads // self.num_key_value_heads


@@ -24,7 +24,6 @@ class ModelArgs(BaseModelArgs):
n_heads: int
vocab_size: int
embedding_size: int
model_type: str
rope_theta: float = 10000
rope_traditional: bool = False
mlp_ratio: int = 4


@@ -11,7 +11,7 @@ from .layers import LayerNorm
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
model_type: str = "phi"
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560
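Giving `model_type` a default means the phi arguments can now be constructed with no parameters at all, which the new `test_phi2` later in this commit relies on:

```python
from mlx_lm.models import phi

args = phi.ModelArgs()  # model_type defaults to "phi"
assert args.model_type == "phi"
```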


@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = (8,)
n_shared_head: int = 8
rope_theta: float = 10000
rope_traditional: bool = False
@@ -80,16 +80,11 @@ class Attention(nn.Module):
bsz, q_len, self.v_num_heads, self.v_dim
).transpose(0, 2, 1, 3)
def _expand_kv(a: mx.array) -> mx.array:
a = mx.concatenate(
[mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
)
return a.reshape([bsz, self.q_num_heads, q_len, -1])
# expand shared kv
assert self.k_num_heads == self.v_num_heads
key_states = _expand_kv(key_states)
value_states = _expand_kv(value_states)
repeats = self.config.n_shared_head
key_states = mx.repeat(key_states, repeats, axis=1)
value_states = mx.repeat(value_states, repeats, axis=1)
kv_seq_len = 0
if cache is not None:
@@ -222,3 +217,7 @@
) -> Tuple[mx.array, mx.array]:
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
@property
def layers(self):
return self.model.layers.layers
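`mx.repeat` replaces the manual expand-and-reshape above with one call. It interleaves copies along the head axis, so each key/value head appears `repeats` times in a row, which is the layout grouped-query attention expects. A small runnable illustration:

```python
import mlx.core as mx

# (bsz=1, kv_heads=2, q_len=1, head_dim=1)
k = mx.array([[[[1.0]], [[2.0]]]])
print(mx.repeat(k, 2, axis=1).squeeze())  # array([1, 1, 2, 2], dtype=float32)
```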


@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x[:, T - 1 : T, :])
return x, cache
return self.ln_f(x), cache
class Model(nn.Module):
@@ -162,3 +161,7 @@ class Model(nn.Module):
) -> Tuple[mx.array, mx.array]:
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
@property
def layers(self):
return self.transformer.h
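With this change the qwen model returns logits for every position instead of only the final one, and the new `layers` property exposes the transformer blocks; both are assumed by the shape assertions in the tests added below. A shape-only sketch of the difference:

```python
import mlx.core as mx

x = mx.zeros((1, 2, 8))    # (batch, T, hidden) activations, T = 2
print(x[:, 1:2, :].shape)  # (1, 1, 8): the old code kept only the last step
print(x.shape)             # (1, 2, 8): the new code normalizes the full sequence
```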


@@ -11,7 +11,6 @@ from .layers import LayerNorm
@dataclass
class ModelArgs(BaseModelArgs):
max_position_embeddings: int
model_type: str
vocab_size: int
hidden_size: int


@@ -15,10 +15,8 @@ class ModelArgs(BaseModelArgs):
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int = None
max_position_embeddings: int = 16384
num_key_value_heads: int
norm_epsilon: float = 1e-5
norm_type: str = "layer_norm"
vocab_size: int = 49152
rope_theta: float = 100000
tie_word_embeddings: bool = True

llms/tests/test_models.py (new file)

@@ -0,0 +1,199 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
class TestModels(unittest.TestCase):
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
self.assertEqual(model.model_type, model_type)
for t in [mx.float32, mx.float16]:
model.update(tree_map(lambda p: p.astype(t), model.parameters()))
inputs = mx.array([[0, 1]])
outputs, cache = model(inputs)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs, cache = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)
def test_llama(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = llama.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi2(self):
from mlx_lm.models import phi
args = phi.ModelArgs()
model = phi.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma(self):
from mlx_lm.models import gemma
args = gemma.ModelArgs(
model_type="gemma",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
head_dim=128,
rms_norm_eps=1e-5,
vocab_size=10_000,
num_key_value_heads=4,
)
model = gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_mixtral(self):
from mlx_lm.models import mixtral
# Make a baby mixtral, because it will actually do the
# eval
args = mixtral.ModelArgs(
model_type="mixtral",
vocab_size=100,
hidden_size=32,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_experts_per_tok=2,
num_key_value_heads=2,
num_local_experts=4,
)
model = mixtral.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
@unittest.skip("requires ai2-olmo")
def test_olmo(self):
from mlx_lm.models import olmo
args = olmo.ModelArgs(
model_type="olmo",
d_model=1024,
n_layers=4,
mlp_hidden_size=2048,
n_heads=2,
vocab_size=10_000,
embedding_size=10_000,
)
model = olmo.Model(args)
self.model_test_runner(
model,
args.model_type,
args.vocab_size,
args.n_layers,
)
def test_qwen2(self):
from mlx_lm.models import qwen2
args = qwen2.ModelArgs(
model_type="qwen2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = qwen2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_qwen(self):
from mlx_lm.models import qwen
args = qwen.ModelArgs(
model_type="qwen",
)
model = qwen.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_plamo(self):
from mlx_lm.models import plamo
args = plamo.ModelArgs(
model_type="plamo",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=8,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = plamo.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_stablelm(self):
from mlx_lm.models import stablelm
args = stablelm.ModelArgs(
model_type="stablelm",
vocab_size=10_000,
hidden_size=1024,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
partial_rotary_factor=1.0,
intermediate_size=2048,
layer_norm_eps=1e-2,
rope_theta=10_000,
use_qkv_bias=False,
)
model = stablelm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_starcoder2(self):
from mlx_lm.models import starcoder2
args = starcoder2.ModelArgs(
model_type="starcoder2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
num_key_value_heads=4,
)
model = starcoder2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()

llms/tests/test_utils.py (new file)

@@ -0,0 +1,45 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import utils
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestUtils(unittest.TestCase):
def test_load(self):
model, _ = utils.load(HF_MODEL_PATH)
model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)
mx.eval(model_lazy.parameters())
p1 = model.layers[0].mlp.up_proj.weight
p2 = model_lazy.layers[0].mlp.up_proj.weight
self.assertTrue(mx.allclose(p1, p2))
def test_make_shards(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=2048,
num_hidden_layers=32,
intermediate_size=4096,
num_attention_heads=32,
rms_norm_eps=1e-5,
vocab_size=30_000,
)
model = llama.Model(args)
weights = tree_flatten(model.parameters())
gb = sum(p.nbytes for _, p in weights) // 2**30
shards = utils.make_shards(dict(weights), 1)
self.assertTrue(gb <= len(shards) <= gb + 1)
if __name__ == "__main__":
unittest.main()
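For reference, the lazy-loading path exercised by `test_load` defers weight materialization until evaluation; a minimal usage sketch (assumes network access to the model above):

```python
import mlx.core as mx
from mlx_lm import utils

model, tokenizer = utils.load("mlx-community/Qwen1.5-0.5B-Chat-4bit", lazy=True)
mx.eval(model.parameters())  # weights are materialized only here
```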