From 3b526f0aa1219fae662a86f012dbda82045f4fb0 Mon Sep 17 00:00:00 2001
From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com>
Date: Tue, 5 Nov 2024 00:23:30 +0400
Subject: [PATCH] Add support for falcon-mamba (#1074)

* Add support for falcon-mamba

* nits

* nit

---------

Co-authored-by: Awni Hannun
---
 llms/README.md              |  1 +
 llms/mlx_lm/models/mamba.py | 11 +++++++++++
 llms/mlx_lm/utils.py        |  1 +
 3 files changed, 13 insertions(+)

diff --git a/llms/README.md b/llms/README.md
index f539988a..0e7dc7fb 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
 - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
 - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
 - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
 
 Most
 [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 84f498e9..f2414660 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
     use_conv_bias: bool
     time_step_rank: int
     tie_word_embeddings: bool = True
+    use_bcdt_rms: bool = False
+    mixer_rms_eps: float = 1e-6
 
     def __post_init__(self):
         if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
 
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)
+        if self.model_type == "falcon_mamba":
+            self.use_bcdt_rms = True
 
 
 class DepthWiseConv1d(nn.Module):
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
         self.intermediate_size = args.intermediate_size
         self.time_step_rank = int(args.time_step_rank)
         self.use_conv_bias = args.use_conv_bias
+        self.use_bcdt_rms = args.use_bcdt_rms
+        if self.use_bcdt_rms:
+            self.mixer_norm = lambda x: mx.fast.rms_norm(
+                x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
+            )
 
         self.in_proj = nn.Linear(
             self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
             ],
             axis=-1,
         )
+        if self.use_bcdt_rms:
+            delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
         new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
         if state is not None:
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9fc202d..7b440db6 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
 MODEL_REMAPPING = {
     "mistral": "llama",  # mistral is compatible with llama
     "phi-msft": "phixtral",
+    "falcon_mamba": "mamba",
 }
 
 MAX_FILE_SIZE_GB = 5
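
A note on the functional change in this patch: `falcon_mamba` reuses the existing Mamba implementation end to end (via the `MODEL_REMAPPING` entry); the only architectural difference is that the `delta`, `B`, and `C` projections are RMS-normalized with a weightless (all-ones) norm before the `softplus`/SSM update. Below is a minimal standalone sketch of that normalization; the batch and feature sizes are hypothetical, chosen only for illustration.

```python
import mlx.core as mx


def mixer_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    # Weightless RMS norm: the all-ones weight vector makes this a pure
    # normalization, mirroring the mixer_norm lambda in the patch.
    return mx.fast.rms_norm(x, mx.ones(x.shape[-1], x.dtype), eps=eps)


# Hypothetical shapes for a single token step (illustration only):
# delta has time_step_rank features; B and C have the SSM state size.
delta = mx.random.normal((1, 256))  # (batch, time_step_rank)
B = mx.random.normal((1, 16))       # (batch, state_size)
C = mx.random.normal((1, 16))       # (batch, state_size)

# With use_bcdt_rms=True (enabled automatically when model_type is
# "falcon_mamba"), all three are normalized before the softplus/SSM
# update; vanilla mamba skips this step entirely.
delta, B, C = map(mixer_norm, (delta, B, C))
```

Using `mx.fast.rms_norm` with a ones weight, rather than an `nn.RMSNorm` module, avoids introducing a learnable parameter that has no counterpart in the falcon-mamba checkpoints.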