Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-14 05:36:38 +08:00
getting really closer:

python -m mlx_lm.generate --model /Users/gokdenizgulmez/Desktop/Mamba-Codestral-7B-v0.1-4bit --prompt "# A function that computes fibonacci def fibonacci(" -m 64
==========
n): print(f"{os.path.abspath(".")/data/data/data/com.android.launcher.png) ## 🙌🏼 🙌🙌🙌🙌🙌🙌 class _State(Enum): def __init__ (self
==========
Prompt: 16 tokens, 84.547 tokens-per-sec
Generation: 64 tokens, 13.774 tokens-per-sec
Peak memory: 4.139 GB
This commit is contained in:
parent eb432f4b7d
commit 5a6ada2df0
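For reference, the same check can be run from Python instead of the CLI. This is a minimal sketch assuming the standard mlx_lm load/generate helpers; the model path and settings simply mirror the command above:

# Sketch of the CLI invocation above via the mlx_lm Python API (assumed helpers).
from mlx_lm import load, generate

model, tokenizer = load("/Users/gokdenizgulmez/Desktop/Mamba-Codestral-7B-v0.1-4bit")
text = generate(
    model,
    tokenizer,
    prompt="# A function that computes fibonacci def fibonacci(",
    max_tokens=64,
    verbose=True,  # prints tokens-per-sec and peak memory, as in the log above
)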
@@ -33,8 +33,7 @@ class ModelArgs(BaseModelArgs):
     time_step_min: float
     time_step_max: float
     time_step_floor: float
-    A_init_min: float = 1.0
-    A_init_max: float = 16.0
+    norm_before_gate: bool = True

     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
@@ -46,17 +45,29 @@ class ModelArgs(BaseModelArgs):


 class MambaRMSNormGated(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
         super().__init__()
         self.weight = mx.ones((hidden_size,))
         self.variance_epsilon = eps
+        self.norm_before_gate = norm_before_gate

-    def __call__(self, hidden_states, gate=None):
-        if gate is not None:
-            hidden_states = hidden_states * nn.silu(gate)
-        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
-        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states
+    def rms_norm(self, x):
+        variance = mx.mean(x ** 2, axis=-1, keepdims=True)
+        x = x * mx.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+    def __call__(self, x, z=None):
+        if z is None:
+            return self.rms_norm(x)
+
+        if self.norm_before_gate:
+            x = self.rms_norm(x)
+            x = x * nn.silu(z)
+        else:
+            x = x * nn.silu(z)
+            x = self.rms_norm(x)
+
+        return x


 def silu(x):
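As a side note, the two orders selected by the new norm_before_gate flag are not interchangeable; a minimal standalone sketch with illustrative shapes only:

# Sketch: norm(x) * silu(z) vs norm(x * silu(z)) generally differ.
import mlx.core as mx
import mlx.nn as nn

def rms_norm(x, weight, eps=1e-6):
    variance = mx.mean(x ** 2, axis=-1, keepdims=True)
    return weight * x * mx.rsqrt(variance + eps)

hidden = 8
x = mx.random.normal((2, 4, hidden))   # (batch, seq_len, d_inner), toy sizes
z = mx.random.normal((2, 4, hidden))   # gate branch from the input projection
w = mx.ones((hidden,))

norm_then_gate = rms_norm(x, w) * nn.silu(z)   # norm_before_gate=True
gate_then_norm = rms_norm(x * nn.silu(z), w)   # norm_before_gate=False (default)
print(mx.allclose(norm_then_gate, gate_then_norm))  # generally False: order matters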
@@ -86,12 +97,71 @@ class DepthWiseConv1d(nn.Module):
         return y, x[:, -K + 1:, :]


+def ssd_forward_attn(
+    x: mx.array,
+    dt: mx.array,
+    A: mx.array,
+    B: mx.array,
+    C: mx.array,
+    D: mx.array,
+    dt_bias: mx.array,
+    dt_min: float,
+    dt_max: float,
+) -> Tuple[mx.array, mx.array]:
+    b, l, h, dh = x.shape
+    _, _, g, _ = B.shape
+
+    if dt_bias is not None:
+        dt = dt + dt_bias.reshape(1, 1, -1)
+
+    dt = nn.softplus(dt)
+    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)
+
+    B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)
+    C = mx.swapaxes(C, 1, 2)
+
+    CB = C @ B
+    CB = mx.repeat(CB, repeats=h // g, axis=1)
+
+    dtA = dt * A.reshape(1, 1, -1)
+    dtA = mx.swapaxes(dtA, 1, 2)
+
+    decay = mx.exp(segsum(dtA))
+
+    surrogate_attention_matrix = mx.tril(CB * decay, 0)
+
+    dtx = dt.reshape(b, l, h, 1) * x
+    y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)
+    y = mx.swapaxes(y, 1, 2)
+
+    decay = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
+    B = mx.repeat(B, h // g, axis=1).swapaxes(2, 3)
+    dtxdecay = dtx * decay
+    dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3)
+    next_state = dtxdecay @ B
+
+    if D is not None:
+        y += x * D.reshape(1, 1, h, 1)
+
+    y = y.reshape(b, l, h * dh)
+
+    return y, next_state
+
+
+def segsum(x):
+    l = x.shape[-1]
+    x = mx.repeat(x[..., None], l, axis=-1)
+    x = mx.tril(x, -1)
+    x_segsum = mx.cumsum(x, axis=-2)
+    return x_segsum
+
+
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args

-        # Same dimensions as before
+        # Dimensions
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
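The ssd_forward_attn/segsum pair added above evaluates the SSM scan as a lower-triangular "surrogate attention" product: segsum builds the cumulative decays exp(sum of dt*A between positions s and t), and the masked (C·Bᵀ)*decay matrix is applied to dt*x. A minimal standalone sketch (single head, scalar A, toy sizes; an illustration rather than the block's actual code path) checking that this parallel form matches the sequential recurrence:

import mlx.core as mx

l, n = 5, 3                          # toy sequence length and state size
dt = mx.random.uniform(shape=(l,)) * 0.5
A = -1.0                             # scalar A for a single head
B = mx.random.normal((l, n))
C = mx.random.normal((l, n))
x = mx.random.normal((l,))           # one channel of one head

# Sequential recurrence: s_t = exp(dt_t * A) * s_{t-1} + dt_t * x_t * B_t, y_t = C_t . s_t
state = mx.zeros((n,))
ys = []
for t in range(l):
    state = mx.exp(dt[t] * A) * state + dt[t] * x[t] * B[t]
    ys.append(mx.sum(C[t] * state))
y_seq = mx.stack(ys)

# Parallel form: decay[t, s] = exp(sum_{k=s+1..t} dt_k * A) (segsum),
# attn[t, s] = (C_t . B_s) * decay[t, s] for t >= s, and y = attn @ (dt * x)
dtA = dt * A
seg = mx.cumsum(mx.tril(mx.repeat(dtA[:, None], l, axis=-1), -1), axis=-2)
attn = mx.tril((C @ B.T) * mx.exp(seg), 0)
y_par = (attn @ (dt * x).reshape(l, 1)).reshape(l)

print(mx.allclose(y_seq, y_par, atol=1e-5))  # expected: True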
@@ -106,14 +176,12 @@ class Mamba2Block(nn.Module):
         d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
         self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)

+        # Parameters
         self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

-        # Same D initialization
-        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
-
-        # Convolution with proper initialization
+        # Convolution
         self.conv1d = DepthWiseConv1d(
             channels=self.d_inner + 2 * self.n_groups * self.d_state,
             kernel_size=self.d_conv,
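A small aside on the A_log parameter initialised above (a sketch with illustrative numbers, not the block's code): storing A as A_log and taking A = -exp(A_log) keeps A strictly negative, so the discretised factor exp(dt * A) stays in (0, 1] for any non-negative dt and the state decays rather than blowing up.

import mlx.core as mx

A_log = mx.random.normal((4,)) * 0.1   # illustrative, akin to initializer_range scaling
A = -mx.exp(A_log)                     # strictly negative for any A_log
dt = mx.abs(mx.random.normal((4,)))    # the block keeps dt positive via softplus
print(A)                 # all entries < 0
print(mx.exp(dt * A))    # all entries in (0, 1]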
@@ -122,7 +190,11 @@ class Mamba2Block(nn.Module):
         )

         # Output projections
-        self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.d_inner,
+            eps=args.layer_norm_epsilon,
+            norm_before_gate=args.norm_before_gate
+        )
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

     def __call__(self, u: mx.array, cache=None):
@@ -131,103 +203,59 @@ class Mamba2Block(nn.Module):
             cache = [None, None]

         # Project input
-        zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
-        A = -mx.exp(self.A_log) # (nheads) or (d_inner, d_state)
+        zxBCdt = self.in_proj(u)

+        # Split projections
         z, xBC, dt = mx.split(
-            zxbcdt,
-            indices_or_sections=[
-                self.d_inner,
-                self.d_inner + (2 * self.n_groups * self.d_state + self.d_inner)
-            ],
+            zxBCdt,
+            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
             axis=-1
         )

-        # Process dt
-        dt = nn.softplus(dt + self.dt_bias) # (B, L, nheads)
-
-        # Conv1d and activation
-        xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None)
+        # Process convolution
+        xBC, conv_state = self.conv1d(xBC, cache[0])
         xBC = silu(xBC)

         if cache is not None:
             cache[0] = conv_state

         xBC = xBC[:, :seq_len, :]

-        # Split conv output and reshape
+        # Split and reshape conv output
         x, B, C = mx.split(
             xBC,
-            indices_or_sections=[
-                self.d_inner,
-                self.d_inner + self.n_groups * self.d_state
-            ],
+            [self.d_inner, self.d_inner + self.d_state * self.n_groups],
             axis=-1
         )

-        # Reshape tensors
+        # Reshape for SSM processing
+        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
-        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, -1))

-        # Initialize state
-        if cache and cache[1] is not None:
-            prev_state = cache[1]
-        else:
-            prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
+        # Get parameters for attention computation
+        A = -mx.exp(self.A_log)

-        # Compute dA
-        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))
-        dA = mx.exp(dt * mx.expand_dims(A, axis=(0, 1)))
+        # Compute parallel attention
+        y, next_state = ssd_forward_attn(
+            x=x,
+            dt=dt,
+            A=A,
+            B=B,
+            C=C,
+            D=self.D,
+            dt_bias=self.dt_bias,
+            dt_min=self.args.time_step_min,
+            dt_max=self.args.time_step_max,
+        )

-        # Process sequence in chunks
-        chunk_size = self.chunk_size
-        outputs = []
-        next_state = prev_state
-
-        # Process in chunks
-        for chunk_start in range(0, seq_len, chunk_size):
-            chunk_end = min(chunk_start + chunk_size, seq_len)
-
-            # Get current chunk
-            x_chunk = x[:, chunk_start:chunk_end]
-            B_chunk = B[:, chunk_start:chunk_end]
-            C_chunk = C[:, chunk_start:chunk_end]
-            dA_chunk = dA[:, chunk_start:chunk_end]
-            z_chunk = z[:, chunk_start:chunk_end]
-
-            # Process the chunk in batches
-            chunk_outputs = []
-            chunk_state = next_state
-
-            for t in range(chunk_end - chunk_start):
-                xt = x_chunk[:, t]
-                Bt = B_chunk[:, t]
-                Ct = C_chunk[:, t]
-                dAt = dA_chunk[:, t]
-
-                # Update state
-                dBx = mx.einsum('bh,bgd,bhp->bhpd', dAt, Bt, xt)
-                chunk_state = chunk_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx
-
-                # Compute output
-                yt = mx.einsum('bhpd,bgd->bhp', chunk_state, Ct)
-                yt = yt + xt * mx.expand_dims(self.D, -1)
-
-                # Reshape and normalize
-                yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
-                yt = self.norm(yt, z_chunk[:, t:t+1])
-                chunk_outputs.append(self.out_proj(yt))
-
-            # Update state for next chunk
-            next_state = chunk_state
-            outputs.extend(chunk_outputs)
-
-        # Update cache with final state
+        # Update cache
         if cache is not None:
             cache[1] = next_state

-        return mx.concatenate(outputs, axis=1)
+        # Apply normalization and output projection
+        y = self.norm(y, z)
+        y = self.out_proj(y)
+
+        return y


 class ResidualBlock(nn.Module):
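To make the corrected split indices in __call__ concrete, a small arithmetic sketch with hypothetical config values (not taken from any real checkpoint):

# Illustrative sizes only; shows how the fused in_proj output divides into z, xBC, dt.
d_model, d_state, n_heads, n_groups = 4096, 128, 64, 8
d_inner = 2 * d_model   # assumes an expansion factor of 2

d_in_proj = 2 * d_inner + 2 * n_groups * d_state + n_heads
split_at = [d_inner, 2 * d_inner + 2 * n_groups * d_state]

z_width = split_at[0]                  # gate branch z: d_inner
xBC_width = split_at[1] - split_at[0]  # conv branch xBC: d_inner + 2 * n_groups * d_state
dt_width = d_in_proj - split_at[1]     # per-head time steps dt: n_heads

assert (z_width, xBC_width, dt_width) == (d_inner, d_inner + 2 * n_groups * d_state, n_heads)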
@@ -238,8 +266,8 @@ class ResidualBlock(nn.Module):
         self.norm = nn.RMSNorm(args.hidden_size)

     def __call__(self, x: mx.array, cache):
-        if self.residual_in_fp32:
-            x = x.astype(mx.float32)
+        # if self.residual_in_fp32:
+        #     x = x.astype(mx.float32)
         normed = self.norm(x)
         output = self.mixer(normed, cache)
         return output + x