Adding support for mamba (#940)

* initial commit

* initial commit

* Adding first lines

* adding x, and dt projection layers

* adding the clamping mechanism

* First succesful inference

* last commit for today - added custom geenrate function and it works as expected, will try training and then with loading a model from the hub

* clean up

* save up

* almost

* update

* update

* fixed cache handeling

* fixed loading

* added seperate generat_step method in the model and also in the utils to automaticaly use the generate step mthod in the model class

* quick update

* still not working

* save

* still not working

* initial commit

* utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple

* update

* update

* Fixing the Batching Depfwise Comnvolution and multi token input

* fixing generate and logits outputs

* Done!

* Fixing the cache handling, generating works now trying training

* update ACKNOWLEDGEMENTS

* removing the model_type if stuff in the _step loop in generate_step and adding MambaCache in base.py for training easier generations and removing mamba in tuner/utils.

* quick clean up

* update trainer/utils for right initialisation of the layers for LoRA, but not working.

* clean up

* Forther update to trainer/utils for correct layer selection. Successfull training

* removing extra mamba-infer.py file

* clean up, reformating will come later

* reformat and big clean up, final commit

* some speedups and cleanups

* fix test

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Gökdeniz Gülmez
2024-09-28 16:02:53 +02:00
committed by GitHub
parent e776c970f7
commit 76710f61af
4 changed files with 263 additions and 8 deletions

View File

@@ -52,7 +52,6 @@ def linear_to_lora_layers(
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
num_layers = len(model.layers)
if num_lora_layers < 0:
@@ -140,6 +139,15 @@ def linear_to_lora_layers(
"self_attn.kv_b_proj",
]
)
elif model.model_type == "mamba":
keys = set(
[
"mixer.in_proj",
"mixer.x_proj",
"mixer.dt_proj",
"mixer.out_proj",
]
)
else:
raise ValueError(f"Lora does not support {model.model_type}")