Add support for falcon-mamba

Ilyas Chahed 2024-10-27 13:28:49 +00:00
parent ab4bf05c6e
commit 49813f524f
3 changed files with 13 additions and 2 deletions

llms/README.md

@@ -221,7 +221,8 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
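
The new README entries can be exercised with mlx_lm's usual `load`/`generate` API. A minimal usage sketch (not part of this commit; the prompt and `max_tokens` are illustrative):

```python
from mlx_lm import load, generate

# Any model from the list above should work; the falcon-mamba
# checkpoints are the ones this commit enables.
model, tokenizer = load("tiiuae/falcon-mamba-7b-instruct")
text = generate(model, tokenizer, prompt="Write a haiku about Mamba.", max_tokens=100)
print(text)
```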

llms/mlx_lm/models/mamba.py

@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
    use_conv_bias: bool
    time_step_rank: int
    tie_word_embeddings: bool = True
    use_bcdt_rms: bool = False
    mixer_rms_eps: float = 1e-6

    def __post_init__(self):
        if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)
        if self.model_type == "falcon_mamba":
            self.use_bcdt_rms = True


class DepthWiseConv1d(nn.Module):
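
For context: `use_bcdt_rms` defaults to `False` and is forced on whenever `model_type` is `falcon_mamba`. A self-contained sketch of that behavior (a simplified stand-in for the real `ModelArgs`, not the actual class):

```python
from dataclasses import dataclass


@dataclass
class ArgsSketch:
    model_type: str
    use_bcdt_rms: bool = False
    mixer_rms_eps: float = 1e-6

    def __post_init__(self):
        # Falcon-Mamba always RMS-normalizes B, C and delta (dt).
        if self.model_type == "falcon_mamba":
            self.use_bcdt_rms = True


assert ArgsSketch(model_type="falcon_mamba").use_bcdt_rms is True
assert ArgsSketch(model_type="mamba").use_bcdt_rms is False
```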
@@ -83,6 +87,9 @@ class MambaBlock(nn.Module):
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias
        self.use_bcdt_rms = args.use_bcdt_rms
        if self.use_bcdt_rms:
            self.mixer_rms_eps = args.mixer_rms_eps

        self.in_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +133,8 @@ class MambaBlock(nn.Module):
            ],
            axis=-1,
        )
        if self.use_bcdt_rms:
            # Falcon-Mamba applies an unweighted RMSNorm to delta, B and C
            # before they enter the SSM update; the weight must be 1-D with
            # the size of each tensor's last axis.
            delta, B, C = (
                mx.fast.rms_norm(
                    t, mx.ones(t.shape[-1], t.dtype), eps=self.mixer_rms_eps
                )
                for t in (delta, B, C)
            )
        delta = nn.softplus(self.dt_proj(delta))
        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
        if state is not None:
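
A quick sanity check (not in the commit) that `mx.fast.rms_norm` with an all-ones weight matches the textbook unweighted RMSNorm `x / sqrt(mean(x**2) + eps)`:

```python
import mlx.core as mx

x = mx.random.normal((2, 5, 16))  # (batch, seq, features)
eps = 1e-6

fast = mx.fast.rms_norm(x, mx.ones(x.shape[-1], x.dtype), eps=eps)
manual = x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + eps)

print(mx.allclose(fast, manual, atol=1e-5))  # True
```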
@@ -214,4 +223,4 @@ class Model(nn.Module):
    @property
    def layers(self):
        return self.backbone.layers

llms/mlx_lm/utils.py

@@ -28,6 +28,7 @@ from .tuner.utils import load_adapters

MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}

MAX_FILE_SIZE_GB = 5
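
The remapping lets `falcon_mamba` checkpoints reuse the existing mamba implementation instead of requiring a new model file. A hedged sketch of how the loader resolves a config's `model_type` through this table (the helper name is hypothetical; it mirrors what `mlx_lm.utils` does internally, assuming `mlx_lm` is installed):

```python
import importlib

MODEL_REMAPPING = {
    "mistral": "llama",
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}


def resolve_model_module(model_type: str):
    # Map architecture aliases onto the module that implements them.
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    return importlib.import_module(f"mlx_lm.models.{model_type}")


print(resolve_model_module("falcon_mamba").__name__)  # mlx_lm.models.mamba
```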