mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
nits
This commit is contained in:
parent
49813f524f
commit
31cb8cac94
@ -89,7 +89,9 @@ class MambaBlock(nn.Module):
|
|||||||
self.use_conv_bias = args.use_conv_bias
|
self.use_conv_bias = args.use_conv_bias
|
||||||
self.use_bcdt_rms = args.use_bcdt_rms
|
self.use_bcdt_rms = args.use_bcdt_rms
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
self.mixer_rms_eps = args.mixer_rms_eps
|
self.mixer_norm = lambda x: mx.fast.rms_norm(
|
||||||
|
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
|
||||||
|
)
|
||||||
|
|
||||||
self.in_proj = nn.Linear(
|
self.in_proj = nn.Linear(
|
||||||
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||||
@ -134,7 +136,7 @@ class MambaBlock(nn.Module):
|
|||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = mx.fast.rms_norm(delta, mx.ones_like(delta[1]), eps = self.mixer_rms_eps), mx.fast.rms_norm(B, mx.ones_like(B[1]), eps = self.mixer_rms_eps), mx.fast.rms_norm(C, mx.ones_like(C[1]), eps = self.mixer_rms_eps)
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
@ -223,4 +225,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.backbone.layers
|
return self.backbone.layers
|
||||||
|
Loading…
Reference in New Issue
Block a user