# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_flatten

from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    def test_load(self):
        # Load the model eagerly (weights materialized immediately).
        model, _ = utils.load(HF_MODEL_PATH)

        # Load the same model lazily, then force evaluation of its
        # parameters.
        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)

        mx.eval(model_lazy.parameters())

        # Both load paths should produce identical weights.
        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))

    def test_make_shards(self):
        from mlx_lm.models import llama

        # A Llama config large enough that its weights span several
        # 1 GB shards.
        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())

        # Total parameter size in GiB (floored) versus shards capped at
        # 1 GB each: the shard count should agree to within one.
        gb = sum(p.nbytes for _, p in weights) // 2**30
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)


if __name__ == "__main__":
    unittest.main()