From 184d3d3267a1fd6521878a700dcbfe1b1302671c Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 10 Dec 2024 18:20:13 +0100
Subject: [PATCH] clean up

---
 llms/mlx_lm/models/mamba2.py | 42 ++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 822ebe8a..b85d3667 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -76,16 +76,16 @@ def ssd(x, A, B, C, chunk_size):
         chunk = slice(i, min(i + chunk_size, seqlen))
         dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
 
-        x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
-        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
-        B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
-        dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
+        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
+        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
+        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
+        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]
 
         state = state * mx.expand_dims(dA, axis=-1) + dBx
 
-        C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
-        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
-        y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
+        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
+        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
+        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
 
         outputs.append(y)
     return mx.concatenate(outputs, axis=1), state
@@ -173,7 +173,7 @@ class Mamba2Block(nn.Module):
         batch_size, seq_len, _ = u.shape
 
         # Project input
-        proj = self.in_proj(u) # [batch, seq_len, d_in_proj]
+        proj = self.in_proj(u)  # [batch, seq_len, d_in_proj]
 
         # Calculate split indices and slice tensors
         z = proj[..., :self.d_inner]
@@ -214,9 +214,9 @@ class Mamba2Block(nn.Module):
 
         # Compute dA
         dA = -mx.exp(self.A_log)  # [n_heads]
-        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) # Ensure correct shape
-        dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1)) # [batch, seq_len, n_heads, 1]
-        dA = mx.expand_dims(dA, -1) # [batch, seq_len, n_heads, 1, 1]
+        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))  # Ensure correct shape
+        dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1))  # [batch, seq_len, n_heads, 1]
+        dA = mx.expand_dims(dA, -1)  # [batch, seq_len, n_heads, 1, 1]
 
         # Process sequence
         next_state = prev_state
@@ -224,26 +224,26 @@ class Mamba2Block(nn.Module):
 
         for t in range(seq_len):
             # Get current step tensors
-            xt = x[:, t] # [batch, n_heads, d_head]
-            Bt = B[:, t] # [batch, n_heads, d_state]
-            Ct = C[:, t] # [batch, n_heads, d_state]
-            dAt = dA[:, t] # [batch, n_heads, 1, 1]
+            xt = x[:, t]  # [batch, n_heads, d_head]
+            Bt = B[:, t]  # [batch, n_heads, d_state]
+            Ct = C[:, t]  # [batch, n_heads, d_state]
+            dAt = dA[:, t]  # [batch, n_heads, 1, 1]
 
             # Update state
             next_state = (
-                next_state * dAt + # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1]
+                next_state * dAt +  # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1]
                 mx.matmul(
-                    mx.expand_dims(xt, -1), # [batch, n_heads, d_head, 1]
-                    mx.expand_dims(Bt, -2) # [batch, n_heads, 1, d_state]
+                    mx.expand_dims(xt, -1),  # [batch, n_heads, d_head, 1]
+                    mx.expand_dims(Bt, -2)  # [batch, n_heads, 1, d_state]
                 )
             )
 
             # Compute output
             yt = mx.matmul(
-                next_state, # [batch, n_heads, d_head, d_state]
-                mx.expand_dims(Ct, -1) # [batch, n_heads, d_state, 1]
+                next_state,  # [batch, n_heads, d_head, d_state]
+                mx.expand_dims(Ct, -1)  # [batch, n_heads, d_state, 1]
             )
-            yt = mx.squeeze(yt, -1) # [batch, n_heads, d_head]
+            yt = mx.squeeze(yt, -1)  # [batch, n_heads, d_head]
             yt = yt + xt * mx.expand_dims(self.D, -1)
 
             # Reshape and normalize
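
Note (not part of the patch): the per-timestep loop above implements the diagonal SSM recurrence h_t = dA_t * h_{t-1} + x_t B_t^T followed by the readout y_t = h_t C_t + D * x_t, i.e. a rank-1 state update and a state contraction per step. Below is a minimal standalone sketch of a single step, assuming mlx is installed; the helper name ssm_step and the small driver shapes are illustrative assumptions, not identifiers from mamba2.py.

import mlx.core as mx

def ssm_step(state, xt, Bt, Ct, dAt, D):
    # state: [batch, n_heads, d_head, d_state], dAt: [batch, n_heads, 1, 1]
    # Decay the previous state, then add the rank-1 update x_t B_t^T.
    state = state * dAt + mx.matmul(mx.expand_dims(xt, -1), mx.expand_dims(Bt, -2))
    # Contract the state with C_t to get the output: [batch, n_heads, d_head].
    yt = mx.squeeze(mx.matmul(state, mx.expand_dims(Ct, -1)), -1)
    # Skip connection scaled by the per-head D parameter, as in the loop above.
    return yt + xt * mx.expand_dims(D, -1), state

# Illustrative driver (hypothetical shapes):
batch, n_heads, d_head, d_state = 1, 2, 4, 8
state = mx.zeros((batch, n_heads, d_head, d_state))
xt = mx.random.normal((batch, n_heads, d_head))
Bt = mx.random.normal((batch, n_heads, d_state))
Ct = mx.random.normal((batch, n_heads, d_state))
dAt = mx.exp(-mx.ones((batch, n_heads, 1, 1)))  # decay factor in (0, 1)
D = mx.zeros((n_heads,))
yt, state = ssm_step(state, xt, Bt, Ct, dAt, D)
print(yt.shape)  # (1, 2, 4)

The chunked ssd path in the first hunk computes the same recurrence, but advances the state one chunk at a time instead of one token at a time.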