This commit is contained in:
Goekdeniz-Guelmez 2024-12-10 18:20:13 +01:00
parent 80e88b4f4d
commit 184d3d3267

View File

@ -76,16 +76,16 @@ def ssd(x, A, B, C, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
@ -173,7 +173,7 @@ class Mamba2Block(nn.Module):
batch_size, seq_len, _ = u.shape
# Project input
proj = self.in_proj(u) # [batch, seq_len, d_in_proj]
proj = self.in_proj(u) # [batch, seq_len, d_in_proj]
# Calculate split indices and slice tensors
z = proj[..., :self.d_inner]
@ -214,9 +214,9 @@ class Mamba2Block(nn.Module):
# Compute dA
dA = -mx.exp(self.A_log) # [n_heads]
dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) # Ensure correct shape
dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1)) # [batch, seq_len, n_heads, 1]
dA = mx.expand_dims(dA, -1) # [batch, seq_len, n_heads, 1, 1]
dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) # Ensure correct shape
dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1)) # [batch, seq_len, n_heads, 1]
dA = mx.expand_dims(dA, -1) # [batch, seq_len, n_heads, 1, 1]
# Process sequence
next_state = prev_state
@ -224,26 +224,26 @@ class Mamba2Block(nn.Module):
for t in range(seq_len):
# Get current step tensors
xt = x[:, t] # [batch, n_heads, d_head]
Bt = B[:, t] # [batch, n_heads, d_state]
Ct = C[:, t] # [batch, n_heads, d_state]
dAt = dA[:, t] # [batch, n_heads, 1, 1]
xt = x[:, t] # [batch, n_heads, d_head]
Bt = B[:, t] # [batch, n_heads, d_state]
Ct = C[:, t] # [batch, n_heads, d_state]
dAt = dA[:, t] # [batch, n_heads, 1, 1]
# Update state
next_state = (
next_state * dAt + # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1]
next_state * dAt + # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1]
mx.matmul(
mx.expand_dims(xt, -1), # [batch, n_heads, d_head, 1]
mx.expand_dims(Bt, -2) # [batch, n_heads, 1, d_state]
mx.expand_dims(xt, -1), # [batch, n_heads, d_head, 1]
mx.expand_dims(Bt, -2) # [batch, n_heads, 1, d_state]
)
)
# Compute output
yt = mx.matmul(
next_state, # [batch, n_heads, d_head, d_state]
mx.expand_dims(Ct, -1) # [batch, n_heads, d_state, 1]
next_state, # [batch, n_heads, d_head, d_state]
mx.expand_dims(Ct, -1) # [batch, n_heads, d_state, 1]
)
yt = mx.squeeze(yt, -1) # [batch, n_heads, d_head]
yt = mx.squeeze(yt, -1) # [batch, n_heads, d_head]
yt = yt + xt * mx.expand_dims(self.D, -1)
# Reshape and normalize