Add support for falcon-mamba (#1074)

* Add support for falcon-mamba
* nits
* nit

Co-authored-by: Awni Hannun <awni@apple.com>

This commit is contained in:
parent 82e3338987
commit 3b526f0aa1
@@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
 - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
 - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
 - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
 
 Most
 [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
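For context, any of the models listed in this section can be loaded by name through the package's Python API. A minimal usage sketch, assuming mlx_lm is installed; the prompt and token budget are illustrative:

    # Minimal usage sketch; the prompt and max_tokens value are illustrative.
    from mlx_lm import load, generate

    model, tokenizer = load("tiiuae/falcon-mamba-7b-instruct")
    text = generate(model, tokenizer, prompt="Write a haiku about the ocean.", max_tokens=100)
    print(text)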
@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
     use_conv_bias: bool
     time_step_rank: int
     tie_word_embeddings: bool = True
+    use_bcdt_rms: bool = False
+    mixer_rms_eps: float = 1e-6
 
     def __post_init__(self):
         if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
 
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)
+        if self.model_type == "falcon_mamba":
+            self.use_bcdt_rms = True
 
 
 class DepthWiseConv1d(nn.Module):
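The two hunks above extend the Mamba ModelArgs: a time_step_rank of "auto" still resolves to ceil(hidden_size / 16), and a model_type of "falcon_mamba" now switches on the B/C/dt RMS normalization. A standalone sketch of that post-init behavior, using a hypothetical ArgsSketch class rather than the package's own:

    # Standalone sketch of the __post_init__ logic above; ArgsSketch is hypothetical.
    import math
    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class ArgsSketch:
        model_type: str
        hidden_size: int
        time_step_rank: Union[int, str] = "auto"
        use_bcdt_rms: bool = False
        mixer_rms_eps: float = 1e-6

        def __post_init__(self):
            if self.time_step_rank == "auto":
                # e.g. hidden_size=4096 -> time_step_rank=256
                self.time_step_rank = math.ceil(self.hidden_size / 16)
            if self.model_type == "falcon_mamba":
                # falcon-mamba RMS-normalizes B, C, and the time step
                self.use_bcdt_rms = True

    args = ArgsSketch(model_type="falcon_mamba", hidden_size=4096)
    assert args.time_step_rank == 256 and args.use_bcdt_rms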
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
         self.intermediate_size = args.intermediate_size
         self.time_step_rank = int(args.time_step_rank)
         self.use_conv_bias = args.use_conv_bias
+        self.use_bcdt_rms = args.use_bcdt_rms
+        if self.use_bcdt_rms:
+            self.mixer_norm = lambda x: mx.fast.rms_norm(
+                x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
+            )
 
         self.in_proj = nn.Linear(
             self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
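The mixer_norm above passes a weight of all ones to mx.fast.rms_norm, so it acts as a weightless RMS norm: each vector is rescaled by the root mean square over its last axis. A minimal reference sketch of the same computation, assuming only that mlx is installed:

    # Reference for the weightless RMS norm: x / sqrt(mean(x^2) + eps) over the last axis.
    import mlx.core as mx

    def rms_norm_ref(x, eps=1e-6):
        return x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + eps)

    x = mx.random.normal((2, 8))
    a = rms_norm_ref(x)
    b = mx.fast.rms_norm(x, mx.ones(x.shape[-1], x.dtype), eps=1e-6)
    print(mx.allclose(a, b, atol=1e-5))  # expected: True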
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
             ],
             axis=-1,
         )
+        if self.use_bcdt_rms:
+            delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
         new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
         if state is not None:
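In this step delta, B, and C are RMS-normalized (when the flag is set) before delta goes through dt_proj and softplus; softplus keeps the discretization step strictly positive. A small reference sketch of that nonlinearity, not the package's own code:

    # softplus(x) = log(1 + exp(x)) maps any real input to a positive step size.
    import mlx.core as mx

    def softplus_ref(x):
        return mx.logaddexp(x, 0.0)

    print(softplus_ref(mx.array([-3.0, 0.0, 3.0])))  # all entries > 0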
@@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
 MODEL_REMAPPING = {
     "mistral": "llama",  # mistral is compatible with llama
     "phi-msft": "phixtral",
+    "falcon_mamba": "mamba",
 }
 
 MAX_FILE_SIZE_GB = 5
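Adding "falcon_mamba" to MODEL_REMAPPING routes such checkpoints to the existing mamba implementation. A sketch of how a remapping table like this is typically consumed when resolving the model module; the mlx_lm.models package path is an assumption for illustration:

    # Sketch of resolving a model_type via the remapping table; the
    # "mlx_lm.models" path is assumed for illustration.
    import importlib

    MODEL_REMAPPING = {
        "mistral": "llama",
        "phi-msft": "phixtral",
        "falcon_mamba": "mamba",
    }

    def resolve_arch(model_type: str):
        arch = MODEL_REMAPPING.get(model_type, model_type)
        return importlib.import_module(f"mlx_lm.models.{arch}")

    # resolve_arch("falcon_mamba") imports the mamba module.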