From 31cb8cac94d1e2a2a9a6435c32164f3044122ece Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 4 Nov 2024 12:19:20 -0800 Subject: [PATCH] nits --- llms/mlx_lm/models/mamba.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 5181f045..f2414660 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -89,7 +89,9 @@ class MambaBlock(nn.Module): self.use_conv_bias = args.use_conv_bias self.use_bcdt_rms = args.use_bcdt_rms if self.use_bcdt_rms: - self.mixer_rms_eps = args.mixer_rms_eps + self.mixer_norm = lambda x: mx.fast.rms_norm( + x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps + ) self.in_proj = nn.Linear( self.hidden_size, self.intermediate_size * 2, bias=args.use_bias @@ -134,7 +136,7 @@ class MambaBlock(nn.Module): axis=-1, ) if self.use_bcdt_rms: - delta, B, C = mx.fast.rms_norm(delta, mx.ones_like(delta[1]), eps = self.mixer_rms_eps), mx.fast.rms_norm(B, mx.ones_like(B[1]), eps = self.mixer_rms_eps), mx.fast.rms_norm(C, mx.ones_like(C[1]), eps = self.mixer_rms_eps) + delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta)) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) if state is not None: @@ -223,4 +225,4 @@ class Model(nn.Module): @property def layers(self): - return self.backbone.layers \ No newline at end of file + return self.backbone.layers