Add support for falcon-mamba (#1074)

* Add support for falcon-mamba

* nits

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
ilyasch2
2024-11-05 00:23:30 +04:00
committed by GitHub
parent 82e3338987
commit 3b526f0aa1
3 changed files with 13 additions and 0 deletions

View File

@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
],
axis=-1,
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None: