mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Adding support for mamba (#940)
* initial commit * initial commit * Adding first lines * adding x, and dt projection layers * adding the clamping mechanism * First succesful inference * last commit for today - added custom geenrate function and it works as expected, will try training and then with loading a model from the hub * clean up * save up * almost * update * update * fixed cache handeling * fixed loading * added seperate generat_step method in the model and also in the utils to automaticaly use the generate step mthod in the model class * quick update * still not working * save * still not working * initial commit * utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple * update * update * Fixing the Batching Depfwise Comnvolution and multi token input * fixing generate and logits outputs * Done! * Fixing the cache handling, generating works now trying training * update ACKNOWLEDGEMENTS * removing the model_type if stuff in the _step loop in generate_step and adding MambaCache in base.py for training easier generations and removing mamba in tuner/utils. * quick clean up * update trainer/utils for right initialisation of the layers for LoRA, but not working. * clean up * Forther update to trainer/utils for correct layer selection. Successfull training * removing extra mamba-infer.py file * clean up, reformating will come later * reformat and big clean up, final commit * some speedups and cleanups * fix test * nits * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import unittest
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||
from mlx_lm.utils import make_kv_caches
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
@@ -100,13 +101,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
cache = make_kv_caches(model)
|
||||
outputs = model(inputs, cache)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
@@ -397,6 +392,26 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_mamba(self):
|
||||
from mlx_lm.models import mamba
|
||||
|
||||
args = mamba.ModelArgs(
|
||||
model_type="mamba",
|
||||
vocab_size=10000,
|
||||
use_bias=False,
|
||||
use_conv_bias=True,
|
||||
conv_kernel=4,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=24,
|
||||
state_size=16,
|
||||
intermediate_size=1536,
|
||||
time_step_rank=48,
|
||||
)
|
||||
model = mamba.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt2(self):
|
||||
from mlx_lm.models import gpt2
|
||||
|
||||
|
Reference in New Issue
Block a user