mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
Add plamo-2-1b model (#1283)
* Add pfnet/plamo-2-1b * Fix cache.py to support non-top level layers * Use mlx's BaseModelArgs * Fix model * Use sanitize() * Remove unnecessary changes * Add plamo2.py * Apply formatter * Fix some part * Allow a cache obj defined externally * Fix channel first weights to channel last for right use of MLX's conv1d * Remove unused code part * Give all inputs when it's the first time call of model * Fix import * Include .jsonl files to download from Huggingface hub * Fix reference to layers * Remove unnecessary code and add a test for plamo2 * Do not pass mask to prepare_inputs_for_generation * Fix to use repeat instead of tile * Add state property to PlamoCache * Add __iter__ and __next__ methods to PlamoCache * cleanup * cleanup * fix --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
if model_type != "mamba":
|
||||
if model_type not in ("mamba", "plamo2"):
|
||||
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
|
||||
outputs = model(inputs, mask=mask)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_plamo2(self):
|
||||
from mlx_lm.models import plamo2
|
||||
|
||||
args = plamo2.ModelArgs(
|
||||
model_type="plamo2",
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=2048,
|
||||
num_attention_heads=8,
|
||||
rms_norm_eps=1e-5,
|
||||
vocab_size=10_000,
|
||||
)
|
||||
model = plamo2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_stablelm(self):
|
||||
from mlx_lm.models import stablelm
|
||||
|
||||
|
Reference in New Issue
Block a user