Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Enable unit testing in Circle and start some MLX LM tests (#545)
* add a few tests for mlx lm
* add a few tests for mlx lm
* add a few tests for mlx lm
* more tests / cleanup
Parent: ef32379bc6
Commit: 7cdd1b69ac
@@ -17,6 +17,30 @@ jobs:
             pre-commit run --all
             if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
 
+  mlx_lm_build_and_test:
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.large.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install dependencies
+          command: |
+            brew install python@3.8
+            python3.8 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install unittest-xml-reporting
+            cd llms/
+            pip install -e .
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            python -m xmlrunner discover -v llms/tests -o test-results/
+      - store_test_results:
+          path: test-results
+
 workflows:
   build_and_test:
     when:
@@ -24,6 +48,7 @@ workflows:
             pattern: "^(?!pull/)[-\\w]+$"
             value: << pipeline.git.branch >>
     jobs:
+      - mlx_lm_build_and_test
       - linux_build_and_test
 
   prb:
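For reference, the new CI step collects JUnit-style results via the unittest-xml-reporting package. A minimal local sketch of what `python -m xmlrunner discover -v llms/tests -o test-results/` does when run from the repository root, assuming `unittest-xml-reporting` is installed (this is an illustration, not part of the commit):

```python
# Rough local equivalent of the CI test step (sketch, not repo code).
# Requires: pip install unittest-xml-reporting
import unittest

import xmlrunner  # provided by the unittest-xml-reporting package

# Discover the same tests the CI job runs and write JUnit XML to test-results/
suite = unittest.defaultTestLoader.discover("llms/tests")
xmlrunner.XMLTestRunner(output="test-results", verbosity=2).run(suite)
```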
@@ -36,3 +36,12 @@ To determine the model layer names, we suggest either:
 
 To add LoRA support edit
 [`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
+
+Finally, add a test for the new model type to the [model
+tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
+
+From the `llms/` directory, you can run the tests with:
+
+```shell
+python -m unittest discover tests/
+```
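Such a test follows the pattern of the `test_models.py` file added in this commit. A minimal sketch of a new test method, added inside the existing `TestModels` class (`mymodel` and its arguments are placeholders, not a real `mlx_lm` module):

```python
# Sketch only: "mymodel" is a hypothetical module name; mirror the existing
# cases in llms/tests/test_models.py for a real architecture.
def test_mymodel(self):
    from mlx_lm.models import mymodel  # hypothetical new model module

    args = mymodel.ModelArgs(
        model_type="mymodel",
        hidden_size=1024,
        num_hidden_layers=4,
        intermediate_size=2048,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
    )
    model = mymodel.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
```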
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
     head_dim: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: int
     rope_theta: float = 10000
     rope_traditional: bool = False
 
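Several hunks in this commit make the same cleanup: fields annotated `int` but defaulted to `None` become required fields. A small generic-Python illustration of the distinction (not repo code):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    num_key_value_heads: int              # required: the caller must pass a value
    optional_heads: Optional[int] = None  # truly optional: None is a valid value

Args(num_key_value_heads=4)  # ok; omitting the required field raises TypeError
```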
@@ -13,7 +13,6 @@ from .layers import RMSNorm
 class ModelArgs(BaseModelArgs):
     model_type: str
     vocab_size: int = 32000
-    max_position_embeddings: int = 4096 * 32
     hidden_size: int = 4096
     intermediate_size: int = 14336
     num_hidden_layers: int = 32
@@ -38,7 +37,6 @@ class MixtralAttention(nn.Module):
         self.num_heads = args.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = args.num_key_value_heads
-        self.max_position_embeddings = args.max_position_embeddings
         self.rope_theta = args.rope_theta
 
         self.repeats = self.num_heads // self.num_key_value_heads
@@ -24,7 +24,6 @@ class ModelArgs(BaseModelArgs):
     n_heads: int
     vocab_size: int
     embedding_size: int
-    model_type: str
     rope_theta: float = 10000
     rope_traditional: bool = False
     mlp_ratio: int = 4
@@ -11,7 +11,7 @@ from .layers import LayerNorm
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str
+    model_type: str = "phi"
     max_position_embeddings: int = 2048
     vocab_size: int = 51200
     hidden_size: int = 2560
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    n_shared_head: int = (8,)
+    n_shared_head: int = 8
     rope_theta: float = 10000
     rope_traditional: bool = False
 
@@ -80,16 +80,11 @@ class Attention(nn.Module):
             bsz, q_len, self.v_num_heads, self.v_dim
         ).transpose(0, 2, 1, 3)
 
-        def _expand_kv(a: mx.array) -> mx.array:
-            a = mx.concatenate(
-                [mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
-            )
-            return a.reshape([bsz, self.q_num_heads, q_len, -1])
-
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        key_states = _expand_kv(key_states)
-        value_states = _expand_kv(value_states)
+        repeats = self.config.n_shared_head
+        key_states = mx.repeat(key_states, repeats, axis=1)
+        value_states = mx.repeat(value_states, repeats, axis=1)
 
         kv_seq_len = 0
         if cache is not None:
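The rewritten block uses `mx.repeat` to expand the shared key/value heads up to the number of query heads. A standalone sketch of that expansion on toy shapes (the values below are illustrative, not from the repo):

```python
import mlx.core as mx

bsz, q_len, head_dim = 1, 5, 4
kv_heads, n_shared = 2, 4            # each KV head is shared by 4 query heads

k = mx.zeros((bsz, kv_heads, q_len, head_dim))
k_expanded = mx.repeat(k, n_shared, axis=1)  # each head repeated along the head axis
print(k_expanded.shape)              # (1, 8, 5, 4): kv_heads * n_shared heads
```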
@@ -222,3 +217,7 @@ class Model(nn.Module):
     ) -> Tuple[mx.array, mx.array]:
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.model.layers.layers
@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
         for e, layer in enumerate(self.h):
             x, cache[e] = layer(x, mask, cache[e])
 
-        x = self.ln_f(x[:, T - 1 : T, :])
-        return x, cache
+        return self.ln_f(x), cache
 
 
 class Model(nn.Module):
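This change makes `QwenModel` return logits for every position instead of only the last token, which is what the shape assertions in the new `test_models.py` expect. A toy illustration of the shape difference (arbitrary values, not repo code):

```python
import mlx.core as mx

x = mx.zeros((1, 2, 8))          # (batch, T, hidden): toy hidden states, T = 2
last_only = x[:, -1:, :]         # old behaviour: keep only the final position
print(last_only.shape, x.shape)  # (1, 1, 8) vs (1, 2, 8)
```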
@@ -162,3 +161,7 @@ class Model(nn.Module):
     ) -> Tuple[mx.array, mx.array]:
         y, cache = self.transformer(x, mask, cache)
         return self.lm_head(y), cache
+
+    @property
+    def layers(self):
+        return self.transformer.h
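Both new `layers` properties give callers a uniform way to reach the decoder blocks whether they live under `model.model.layers` or `model.transformer.h`; the new tests rely on this via `len(model.layers)`. A sketch of the kind of architecture-agnostic code this enables (the helpers below are illustrative, not from the repo):

```python
# Illustrative helpers, not part of mlx_lm.
def count_blocks(model):
    return len(model.layers)

def freeze_all_but_last(model, n=2):
    for layer in model.layers[:-n]:
        layer.freeze()  # mlx.nn.Module.freeze is available on MLX modules
```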
@@ -11,7 +11,6 @@ from .layers import LayerNorm
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    max_position_embeddings: int
     model_type: str
     vocab_size: int
     hidden_size: int
@@ -15,10 +15,8 @@ class ModelArgs(BaseModelArgs):
     num_hidden_layers: int
     intermediate_size: int
     num_attention_heads: int
-    num_key_value_heads: int = None
-    max_position_embeddings: int = 16384
+    num_key_value_heads: int
     norm_epsilon: float = 1e-5
-    norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
llms/tests/test_models.py (new file, 199 lines)
@@ -0,0 +1,199 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_map


class TestModels(unittest.TestCase):

    def model_test_runner(self, model, model_type, vocab_size, num_layers):

        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            model.update(tree_map(lambda p: p.astype(t), model.parameters()))

            inputs = mx.array([[0, 1]])
            outputs, cache = model(inputs)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

            outputs, cache = model(mx.argmax(outputs[1, :], keepdims=True), cache=cache)
            self.assertEqual(outputs.shape, (1, 1, vocab_size))
            self.assertEqual(outputs.dtype, t)

    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the
        # eval
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen(self):
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(
            model_type="qwen",
        )
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )


if __name__ == "__main__":
    unittest.main()
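These are plain `unittest` tests, so a single case can be run on its own while developing. A sketch using standard library calls only (nothing repo-specific), run from the `llms/` directory:

```python
# Run just one test from this file programmatically (equivalent to
# `python -m unittest tests.test_models.TestModels.test_llama -v`).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.test_models.TestModels.test_llama"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```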
llms/tests/test_utils.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    def test_load(self):
        model, _ = utils.load(HF_MODEL_PATH)

        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)

        mx.eval(model_lazy.parameters())

        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))

    def test_make_shards(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())
        gb = sum(p.nbytes for _, p in weights) // 2**30
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)


if __name__ == "__main__":
    unittest.main()
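`test_make_shards` checks that `utils.make_shards` splits the flattened weight dictionary into roughly 1 GB chunks. A rough sketch of what a splitter with that contract does (illustrative only, not the actual `mlx_lm.utils.make_shards` implementation):

```python
# Illustrative size-bounded sharder; mlx_lm's real make_shards may differ.
def make_shards_sketch(weights: dict, max_file_size_gb: int = 5):
    max_bytes = max_file_size_gb << 30
    shards, shard, shard_size = [], {}, 0
    for name, array in weights.items():
        # start a new shard once adding this array would exceed the limit
        if shard_size + array.nbytes > max_bytes and shard:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[name] = array
        shard_size += array.nbytes
    if shard:
        shards.append(shard)
    return shards
```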