From 55485b98e88bd9131c06cce59a30649d479c733c Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 22 Oct 2024 21:23:47 +0200
Subject: [PATCH] update

---
 llms/mlx_lm/models/cache.py       | 21 +++++++++++
 llms/mlx_lm/models/mamba2-prch.py |  1 +
 llms/mlx_lm/models/mamba2.py      | 62 ++++++++++++++++++++-----------
 3 files changed, 63 insertions(+), 21 deletions(-)

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index a6a56e0a..32343ae0 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -338,3 +338,24 @@ class MambaCache(_BaseCache):
     @state.setter
     def state(self, v):
         self.cache = v
+
+
+class Mamba2Cache(_BaseCache):
+    """Cache for Mamba2 model inference, holding the conv cache and SSM state."""
+    conv_cache: Optional[mx.array] = None
+    ssm_state: Optional[mx.array] = None
+
+    def __getitem__(self, idx: int) -> Optional[mx.array]:
+        if idx == 0:
+            return self.conv_cache
+        elif idx == 1:
+            return self.ssm_state
+        raise IndexError("Cache index must be 0 or 1")
+
+    def __setitem__(self, idx: int, value: Optional[mx.array]):
+        if idx == 0:
+            self.conv_cache = value
+        elif idx == 1:
+            self.ssm_state = value
+        else:
+            raise IndexError("Cache index must be 0 or 1")
\ No newline at end of file
diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py
index da5de3e9..f9bd6797 100644
--- a/llms/mlx_lm/models/mamba2-prch.py
+++ b/llms/mlx_lm/models/mamba2-prch.py
@@ -193,6 +193,7 @@ class Mamba2(nn.Module):
         self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
         self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
         self.D = nn.Parameter(torch.empty(args.nheads, device=device))
+        self.norm = RMSNorm(args.d_inner, device=device)
 
         self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 3c4b4e7c..2e0c6a59 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
         if cache is not None and self.args.use_cache:
             return self.step(u, cache)
 
+        # Calculate sizes
+        d_model = self.args.intermediate_size
+        d_state = self.args.state_size
+        n_heads = self.args.num_heads
+
+        # Compute A
         A = -mx.exp(self.A_log)
+
+        # Project input
         zxbcdt = self.in_proj(u)
 
+        # Split sizes for z, xBC, dt
         splits = [
-            self.args.intermediate_size,
-            self.args.intermediate_size + 2 * self.args.state_size,
-            self.args.num_heads,
+            d_model,  # z
+            d_model + 2 * d_state,  # xBC (x, B, C concatenated)
+            n_heads,  # dt
         ]
-
-        z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
+        # Slice z, xBC, dt out of the projection along the last axis
+        z = zxbcdt[:, :, :splits[0]]
+        xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
+        dt = zxbcdt[:, :, -splits[2]:]
+
+        # Process dt
         dt = mx.clip(
             nn.softplus(dt + self.dt_bias),
             self.args.time_step_min,
             self.args.time_step_max
         )
-
         dt = mx.maximum(dt, self.args.time_step_floor)
 
+        # Process convolution
         xBC = silu(self.conv1d(xBC))
 
-        xBC_parts = mx.split(
-            xBC,
-            [self.args.intermediate_size, self.args.state_size, self.args.state_size],
-            axis=-1
-        )
+        # Split convolved xBC into x, B, C
+        x = xBC[:, :, :d_model]
+        B = xBC[:, :, d_model:d_model + d_state]
+        C = xBC[:, :, -d_state:]
 
-        x = xBC_parts[0]
-        B = xBC_parts[1]
-        C = xBC_parts[2]
-
-        # Replace rearrange with reshape and transpose
+        # Reshape for SSM computation
         b, l, hp = x.shape
         h = self.args.num_heads
         p = hp // h
         x = mx.reshape(x, (b, l, h, p))
-
+
+        # Compute SSM
         y, ssm_state = ssd(
             x * mx.expand_dims(dt, -1),
             A * dt,
@@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
             C,
             self.args.chunk_size
         )
-
+
+        # Add skip connection
         y = y + x * mx.expand_dims(self.D, -1)
 
-        # Replace rearrange with reshape
-        y = mx.reshape(y, (b, l, h * p))
+        # Reshape back
+        y = mx.reshape(y, (b, l, h * p))
+
+        # Apply norm and projection
         y = self.norm(y + z)
         y = self.out_proj(y)
 
+        # Update cache if needed
         if cache is not None and self.args.use_cache:
             cache[1] = ssm_state
 
+        # Cast if needed
         if self.args.residual_in_fp32:
-            y = mx.cast(y, mx.float32)
+            y = y.astype(mx.float32)
 
         return y
 
     def step(self, u: mx.array, cache: MambaCache):
+        """
+        Process one or more tokens while maintaining recurrent state.
+        Args:
+            u: Input tensor of shape (batch_size, seq_len, hidden_size)
+            cache: MambaCache object holding the conv cache and SSM state
+        """
         batch_size = u.shape[0]
         seq_len = u.shape[1]
         outputs = []
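
A note on the split logic above: mx.split interprets a list argument as cumulative cut indices along the axis, not as chunk sizes, which is presumably why the patch replaces the old mx.split call (which passed sizes and would have produced four pieces) with explicit slices. The following standalone check, not part of the patch and using hypothetical sizes standing in for args.intermediate_size, args.state_size, and args.num_heads, confirms the slices partition the projection exactly as a correct mx.split call would:

import mlx.core as mx

# Hypothetical sizes standing in for args.intermediate_size,
# args.state_size, and args.num_heads.
d_model, d_state, n_heads = 8, 4, 2
d_in = 2 * d_model + 2 * d_state + n_heads  # z + xBC + dt

zxbcdt = mx.arange(3 * d_in, dtype=mx.float32).reshape(1, 3, d_in)

# Slicing, as in the patch
z = zxbcdt[:, :, :d_model]
xBC = zxbcdt[:, :, d_model:2 * d_model + 2 * d_state]
dt = zxbcdt[:, :, -n_heads:]

# mx.split expects cumulative cut points (running sums of the
# chunk sizes), not the sizes themselves.
z2, xBC2, dt2 = mx.split(zxbcdt, [d_model, 2 * d_model + 2 * d_state], axis=-1)

assert mx.array_equal(z, z2)
assert mx.array_equal(xBC, xBC2)
assert mx.array_equal(dt, dt2)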
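
For completeness, a minimal usage sketch of the new Mamba2Cache and its two-slot indexing protocol. The import path matches the file the patch modifies; the array shapes are illustrative assumptions only, not values fixed by the patch:

import mlx.core as mx

# Assumes the patch above has been applied, so Mamba2Cache exists.
from mlx_lm.models.cache import Mamba2Cache

cache = Mamba2Cache()
assert cache[0] is None and cache[1] is None  # both slots start empty

# Slot 0 holds the rolling conv input window, slot 1 the SSM state.
# These shapes are illustrative assumptions, not fixed by the patch.
cache[0] = mx.zeros((1, 3, 24))    # e.g. (batch, kernel_size - 1, conv channels)
cache[1] = mx.zeros((1, 2, 4, 8))  # e.g. (batch, n_heads, head_dim, state_size)

try:
    cache[2]
except IndexError:
    pass  # only indices 0 and 1 are valid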