mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-20 02:21:15 +08:00
inference with the origional mamba2 model woirks but still not with codestral. working:
rokyang/mamba2-130m-hf rokyang/mamba2-370m-hf rokyang/mamba2-780m-hf rokyang/mamba2-1.3b-hf rokyang/mamba2-2.7b-hf
This commit is contained in:
parent
be4bc7a090
commit
eb432f4b7d
@ -28,7 +28,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
rms_norm: bool
|
rms_norm: bool
|
||||||
chunk_size: int
|
chunk_size: int
|
||||||
tie_word_embeddings: bool
|
tie_word_embeddings: bool
|
||||||
intermediate_size: int
|
|
||||||
time_step_limit: Tuple[float, float]
|
time_step_limit: Tuple[float, float]
|
||||||
time_step_rank: Union[int, str]
|
time_step_rank: Union[int, str]
|
||||||
time_step_min: float
|
time_step_min: float
|
||||||
@ -106,15 +105,6 @@ class Mamba2Block(nn.Module):
|
|||||||
# Input projection
|
# Input projection
|
||||||
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
|
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
|
||||||
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
|
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
|
||||||
|
|
||||||
# Improved initialization of dt
|
|
||||||
dt = mx.exp(
|
|
||||||
mx.random.uniform(
|
|
||||||
low=math.log(args.time_step_min),
|
|
||||||
high=math.log(args.time_step_max),
|
|
||||||
shape=(self.n_heads,)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
|
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
|
||||||
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
|
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
|
||||||
|
Loading…
Reference in New Issue
Block a user