Fixing the Batching Depthwise Convolution and multi-token input

This commit is contained in:
Goekdeniz-Guelmez
2024-09-04 22:08:32 +02:00
parent e6c96f2b7a
commit 236acb16a8
3 changed files with 149 additions and 702 deletions

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
import math
@@ -32,6 +33,8 @@ class ModelArgs(BaseModelArgs):
use_cache: bool
use_mambapy: bool = False
dt_rank: str = "auto"
tie_word_embeddings: bool = True
def __post_init__(self):
if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
@@ -40,12 +43,6 @@ class ModelArgs(BaseModelArgs):
self.intermediate_size = self.d_inner
if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
self.state_size = self.d_state
if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'):
self.time_step_min = self.dt_min
if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'):
self.time_step_max = self.dt_max
if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'):
self.time_step_floor = self.dt_init_floor
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
self.num_hidden_layers = self.n_layer
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
@@ -61,6 +58,56 @@ class ModelArgs(BaseModelArgs):
if self.dt_rank == "auto":
self.dt_rank = math.ceil(self.hidden_size / 16)
class DepthWiseConv1d(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
padding: int = 0
):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((channels, 1, kernel_size))
if bias:
self.bias = mx.zeros((channels,))
else:
self.bias = None
def __call__(self, x, cache=None):
B, L, C = x.shape
assert C == self.channels, f"Input channels ({C}) must match the initialized channels ({self.channels})."
w = self.weight # Shape: (C, 1, K)
K = self.kernel_size
total_padding = self.padding + K - 1
if cache is not None:
l = []
if cache.shape[1] < total_padding:
l.append(mx.zeros((B, total_padding - cache.shape[1], C), dtype=x.dtype))
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
else:
x = mx.pad(x, [(0, 0), (total_padding, 0), (0, 0)])
# Manual depthwise convolution
output = []
for i in range(K):
slice = x[:, i:i+L, :]
output.append(slice * w[:, 0, i])
y = mx.sum(mx.stack(output), axis=0)
# The cache always keeps the last `total_padding` time steps
cache = x[:, max(x.shape[1] - total_padding, 0):, :]
if self.bias is not None:
y = y + self.bias.reshape(1, 1, -1)
return y, cache
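# --- Editor's sketch (not part of the commit): one hedged way to exercise the cached
# --- depthwise conv above, with a multi-token prompt pass followed by a single-token
# --- step. The concrete sizes are arbitrary assumptions; only the (B, L, C) layout matters.
def _demo_depthwise_cache():
    conv = DepthWiseConv1d(channels=8, kernel_size=4, bias=True, padding=3)
    x = mx.random.normal((1, 5, 8))                # (B, T, C): a 5-token prompt
    y, conv_cache = conv(x)                        # y: (1, 5, 8); cache keeps the trailing context
    x_step = mx.random.normal((1, 1, 8))           # one new token
    y_step, conv_cache = conv(x_step, conv_cache)  # y_step: (1, 1, 8); cache threaded between calls
    return y.shape, y_step.shape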
def clamp(x, min=None, max=None):
if min is not None:
@@ -72,54 +119,6 @@ def clamp(x, min=None, max=None):
return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
return mx.where(mask_lower, min, x)
return mx.where(mask_upper, max, x)
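# --- Editor's example (not part of the commit): clamp() is element-wise and built from
# --- mx.where, with scalar bounds broadcasting against x.
def _demo_clamp():
    both = clamp(mx.array([-1.0, 0.5, 2.0]), min=0.0, max=1.0)  # -> [0.0, 0.5, 1.0]
    lower_only = clamp(mx.array([-1.0, 0.5, 2.0]), min=0.0)     # -> [0.0, 0.5, 2.0]
    return both, lower_only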
class Conv1d(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
padding: int = 0
):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.use_bias = bias
self.padding = padding
# Change the weight initialization to match the expected shape
self.weight = mx.zeros((kernel_size, 1, channels))
if self.use_bias:
self.bias = mx.zeros((channels,))
else:
self.bias = None
def __call__(self, x, cache=None):
# Use the weight directly without transposing
w = self.weight
if cache is not None:
l = []
# Pad the cache if needed
if cache.shape[1] < self.kernel_size - 1:
l.append(
mx.zeros(
(x.shape[0], self.kernel_size - 1 - cache.shape[1], self.channels), dtype=x.dtype
)
)
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
y = mx.conv_general(x, w, padding=([0], [0]), groups=self.channels)
else:
y = mx.conv_general(x, w, padding=([self.padding], [0]), groups=self.channels)
# The cache is always kernel_size - 1
cache = x[:, max(x.shape[1] - self.kernel_size + 1, 0) :, :]
if self.use_bias:
y = y + self.bias
return y, cache
class MambaBlock(nn.Module):
@@ -127,50 +126,46 @@ class MambaBlock(nn.Module):
super().__init__()
self.args = args
# projects block input from D to 2*ED (two branches)
self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias)
self.hidden_size = args.hidden_size
self.ssm_state_size = args.state_size
self.conv_kernel_size = args.conv_kernel
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
# short 1d conv over time
self.conv1d = Conv1d(
channels=args.intermediate_size,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
padding=args.conv_kernel-1
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias)
self.conv1d = DepthWiseConv1d(
channels=self.intermediate_size,
kernel_size=self.conv_kernel_size,
bias=self.use_conv_bias,
padding=self.conv_kernel_size-1
)
# projects x to input-dependent Δ, B, C
self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False)
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
# projects Δ from dt_rank to intermediate_size
self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True)
# dt initialization
# dt weights
dt_init_std = args.dt_rank**-0.5 * args.state_size
dt_init_std = args.time_step_rank**-0.5 * args.state_size
if args.time_step_init_scheme == "constant":
self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
elif args.time_step_init_scheme == "random":
self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
else:
raise NotImplementedError
# dt bias
dt = clamp(mx.exp(
mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
), min=args.time_step_floor)
inv_dt = dt + mx.log1p(-mx.exp(-dt))
self.dt_proj.bias = inv_dt
# S4D real initialization
A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0)
self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
self.D = mx.ones([args.intermediate_size])
A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0)
self.A_log = mx.log(A)
self.D = mx.ones([self.intermediate_size])
# projects block output from ED back to D
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm(self, x, h):
def ssm_step(self, x, h):
# x : (B, ED)
# h : (B, ED, N)
@@ -182,7 +177,7 @@ class MambaBlock(nn.Module):
deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N)
delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N)
delta = nn.softplus(self.dt_proj(delta)) # (B, ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N)
@@ -191,51 +186,55 @@ class MambaBlock(nn.Module):
BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N)
if h is None:
h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N)
h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N)
h = deltaA * h + BX # (B, ED, N)
y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
y = y + D * x
return y, h
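# --- Editor's note (not part of the commit): per token and per channel, ssm_step realizes
# --- the selective state-space recurrence
# ---     h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * x_t   (elementwise / broadcast)
# ---     y_t = C_t . h_t + D * x_t                                   (dot over the state dim N)
# --- where delta_t, B_t, C_t come from x_proj/dt_proj applied to x_t (the "selective" part).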
def __call__(self, x, cache):
# x : (B, D)
# cache : (h, inputs)
# h : (B, ED, N)
# inputs : (B, conv_kernel-1, ED)
# y : (B, D)
# cache : (h, inputs)
# x : (B, T, D) where T is the number of tokens
# cache : (h, inputs)
# h : (B, ED, N)
# inputs : (B, d_conv-1, ED)
h, inputs = cache
print("Input shape:", x.shape)
xz = self.in_proj(x) # (B, 2*ED)
xz = xz.reshape(x.shape[0], -1) # Ensure shape is (B, 2*ED)
print("After in_proj shape:", xz.shape)
x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED)
B, T, D = x.shape
# x branch
x_cache = mx.expand_dims(x, 1)
x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED)
outputs = []
for t in range(T):
xt = x[:, t, :] # (B, D)
xz = self.in_proj(xt) # (B, 2*ED)
x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED)
x = nn.silu(x)
y, h = self.ssm_step(x, h)
# x branch
x_cache = mx.expand_dims(x_t, 1) # (B, 1, ED)
conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED)
conv_out, new_inputs = self.conv1d(conv_input) # (B, d_conv, ED), (B, d_conv-1, ED)
x_t = conv_out[:, -1, :] # (B, ED)
# z branch
z = nn.silu(z)
x_t = nn.silu(x_t)
y_t, h = self.ssm_step(x_t, h)
output = y * z
output = self.out_proj(output) # (B, D)
# z branch
z_t = nn.silu(z_t)
# prepare cache for next call
inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED)
output_t = y_t * z_t
output_t = self.out_proj(output_t) # (B, D)
outputs.append(output_t)
# Update inputs for next token
inputs = new_inputs
output = mx.stack(outputs, axis=1) # (B, T, D)
cache = (h, inputs)
return output, cache
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
@@ -243,19 +242,12 @@ class ResidualBlock(nn.Module):
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, inputs: mx.array, cache):
# x : (B, D)
# cache : (h, inputs)
# h : (B, ED, N)
# inputs: (B, conv_kernel-1, ED)
# output : (B, D)
# cache : (h, inputs)
output, cache = self.mixer(self.norm(inputs), cache)
output = output + inputs
def __call__(self, x: mx.array, cache):
output, cache = self.mixer(self.norm(x), cache)
output = output + x
return output, cache
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -263,22 +255,11 @@ class Mamba(nn.Module):
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, tokens: mx.array, caches):
# tokens : (B, L)
# logits : (B, L, vocab_size)
x = self.embeddings(tokens)
# x : (B, L, D)
# caches : [cache(layer) for all layers], cache : (h, inputs)
# y : (B, L, D)
# caches : [cache(layer) for all layers], cache : (h, inputs)
def __call__(self, x: mx.array, caches):
x = self.embeddings(x)
print(x.shape)
for i, layer in enumerate(self.layers):
x, caches[i] = layer(x, caches[i])
return x, caches
@@ -289,10 +270,39 @@ class Model(nn.Module):
self.model_type = args.model_type
self.backbone = Mamba(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
out, cache = self.backbone(inputs, cache)
# out = self.backbone.embeddings.as_linear(out)
return out, cache
# inputs : (B, T) where T is the number of tokens
# caches : [cache(layer) for all layers], cache : (h, inputs)
if inputs.ndim == 1:
inputs = mx.expand_dims(inputs, 0) # Add batch dimension if not present
B, T = inputs.shape
x = self.backbone.embeddings(inputs) # (B, T, D)
for i, layer in enumerate(self.backbone.layers):
x, cache[i] = layer(x, cache[i])
x = self.backbone.norm_f(x)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
print(f"Logits shape: {logits.shape}")
# logits : (B, T, vocab_size)
print(logits)
return logits, cache
def make_cache(self):
B = 1 # Assuming batch size of 1 for simplicity
return [(None, mx.zeros((B, self.args.conv_kernel-1, self.args.intermediate_size)))
for _ in range(self.args.num_hidden_layers)]
@property
def layers(self):
@@ -306,17 +316,5 @@ class Model(nn.Module):
def n_kv_heads(self):
return self.args.num_hidden_layers
def make_cache(self):
return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
def sanitize(self, weights):
for key, value in weights.items():
if "mixer.conv1d.weight" in key:
# Ensure the weight is in the shape (kernel_size, 1, channels)
if value.shape != (self.args.conv_kernel, 1, self.args.intermediate_size):
weights[key] = value.reshape(self.args.conv_kernel, 1, self.args.intermediate_size)
elif key == "backbone.embeddings.weight":
# Ensure the embedding weight is in the shape (vocab_size, hidden_size)
if value.shape != (self.args.vocab_size, self.args.hidden_size):
weights[key] = value.T
return weights
# def make_cache(self):
# return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
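# --- Editor's sketch (not part of the commit): end-to-end use of the updated Model with a
# --- multi-token prompt and cached greedy decoding. Every config value below is an arbitrary
# --- assumption chosen only to keep the model tiny; the token ids are placeholders.
def _demo_multi_token_decode():
    args = ModelArgs(
        model_type="mamba", vocab_size=100, hidden_size=32, intermediate_size=64,
        state_size=16, num_hidden_layers=2, layer_norm_epsilon=1e-5, expand=2,
        conv_kernel=4, use_bias=False, use_conv_bias=True, initializer_range=0.1,
        time_step_rank=2, time_step_scale=1.0, time_step_min=0.001, time_step_max=0.1,
        time_step_init_scheme="random", time_step_floor=1e-4,
        rescale_prenorm_residual=False, use_cache=True,
    )
    model = Model(args)
    cache = model.make_cache()
    prompt = mx.array([[12, 7, 42, 3, 99]])           # (B, T) prompt of 5 token ids
    logits, cache = model(prompt, cache)              # logits: (B, T, vocab_size)
    next_tok = mx.argmax(logits[:, -1, :], axis=-1)   # greedy pick at the last position
    for _ in range(8):                                # decode a few more tokens, one at a time
        logits, cache = model(mx.expand_dims(next_tok, axis=1), cache)
        next_tok = mx.argmax(logits[:, -1, :], axis=-1)
    return next_tok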

View File

@@ -1,293 +0,0 @@
from dataclasses import dataclass
import math
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
time_step_rank: int
time_step_scale: float
time_step_min: float
time_step_max: float
time_step_init_scheme: str
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
use_mambapy: bool = False
dt_rank: str = "auto"
def __post_init__(self):
if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
self.hidden_size = self.d_model
if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'):
self.intermediate_size = self.d_inner
if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
self.state_size = self.d_state
if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'):
self.time_step_min = self.dt_min
if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'):
self.time_step_max = self.dt_max
if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'):
self.time_step_floor = self.dt_init_floor
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
self.num_hidden_layers = self.n_layer
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
self.num_hidden_layers = self.n_layers
if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'):
self.conv_kernel = self.d_conv
if not hasattr(self, 'use_bias') and hasattr(self, 'bias'):
self.use_bias = self.bias
if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'):
self.use_conv_bias = self.conv_bias
self.intermediate_size = self.expand * self.hidden_size
if self.dt_rank == "auto":
self.dt_rank = math.ceil(self.hidden_size / 16)
def clamp(x, min=None, max=None):
if min is not None:
mask_lower = x < min
if max is not None:
mask_upper = x > max
if min is not None:
if max is not None:
return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
return mx.where(mask_lower, min, x)
return mx.where(mask_upper, max, x)
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias, padding):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.bias = bias
self.padding = padding
self.conv1d = nn.Conv1d(
in_channels=channels,
out_channels=channels,
kernel_size=kernel_size,
bias=True,
padding=padding
)
indices = mx.arange(channels)
mask = mx.zeros_like(self.conv1d.weight)
mask[indices, :, indices] = 1
self.conv1d.weight *= mask
def __call__(self, x):
return self.conv1d(x)
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# projects block input from D to 2*ED (two branches)
self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias)
# short 1d conv over time
self.conv1d = DepthWiseConv1d(
channels=args.intermediate_size,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
padding=args.conv_kernel-1
)
# projects x to input-dependent Δ, B, C
self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False)
# projects Δ from dt_rank to intermediate_size
self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True)
# dt initialization
# dt weights
dt_init_std = args.dt_rank**-0.5 * args.state_size
if args.time_step_init_scheme == "constant":
self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
elif args.time_step_init_scheme == "random":
self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
else:
raise NotImplementedError
# dt bias
dt = clamp(mx.exp(
mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
), min=args.time_step_floor)
inv_dt = dt + mx.log1p(-mx.exp(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
self.dt_proj.bias = inv_dt
# S4D real initialization
A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0)
self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
self.D = mx.ones([args.intermediate_size])
# projects block output from ED back to D
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
def ssm(self, x, h):
# x : (B, ED)
# h : (B, ED, N)
# y : (B, ED)
# h : (B, ED, N)
A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent)
D = self.D
deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N)
delta = nn.softplus(self.dt_proj(delta)) # (B, ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N)
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N)
BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N)
if h is None:
h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N)
h = deltaA * h + BX # (B, ED, N)
y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
y = y + D * x
return y, h
def __call__(self, x, cache):
# x : (B, D)
# cache : (h, inputs)
# h : (B, ED, N)
# inputs : (B, conv_kernel-1, ED)
# y : (B, D)
# cache : (h, inputs)
h, inputs = cache
xz = self.in_proj(x) # (B, 2*ED)
x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED)
# x branch
x_cache = mx.expand_dims(x, 1)
x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED)
x = nn.silu(x)
y, h = self.ssm(x, h)
# z branch
z = nn.silu(z)
output = y * z
output = self.out_proj(output) # (B, D)
# prepare cache for next call
inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED)
cache = (h, inputs)
return output, cache
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, inputs: mx.array, cache):
# x : (B, D)
# cache : (h, inputs)
# h : (B, ED, N)
# inputs: (B, conv_kernel-1, ED)
# output : (B, D)
# cache : (h, inputs)
output, cache = self.mixer(self.norm(inputs), cache)
output = output + inputs
return output, cache
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, tokens: mx.array, caches):
# tokens : (B, L)
# logits : (B, L, vocab_size)
x = self.embeddings(tokens)
# x : (B, L, D)
# caches : [cache(layer) for all layers], cache : (h, inputs)
# y : (B, L, D)
# caches : [cache(layer) for all layers], cache : (h, inputs)
for i, layer in enumerate(self.layers):
x, caches[i] = layer(x, caches[i])
return x, caches
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba(args)
def __call__(self, inputs: mx.array, cache=None):
out, cache = self.backbone(inputs, cache)
# out = self.backbone.embeddings.as_linear(out)
return out, cache
@property
def layers(self):
return self.backbone.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_hidden_layers
@property
def n_kv_heads(self):
return self.args.num_hidden_layers
def make_cache(self):
# return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
new_weights = {}
for key, value in weights.items():
if "mixer.conv1d.weight" in key:
weights[key] = value.T
new_key = key.replace('mixer.conv1d', 'mixer.conv1d.conv1d')
new_weights[new_key] = value
return new_weights

View File

@@ -1,258 +0,0 @@
from dataclasses import dataclass
import math
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
time_step_rank: int
time_step_scale: float
time_step_min: float
time_step_max: float
time_step_init_scheme: str
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
use_mambapy: bool = False
dt_rank: str = "auto"
def __post_init__(self):
if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
self.hidden_size = self.d_model
if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'):
self.intermediate_size = self.d_inner
if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
self.state_size = self.d_state
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
self.num_hidden_layers = self.n_layer
if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
self.num_hidden_layers = self.n_layers
if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'):
self.conv_kernel = self.d_conv
if not hasattr(self, 'use_bias') and hasattr(self, 'bias'):
self.use_bias = self.bias
if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'):
self.use_conv_bias = self.conv_bias
self.intermediate_size = self.expand * self.hidden_size
if self.dt_rank == "auto":
self.dt_rank = math.ceil(self.hidden_size / 16)
def clamp(x, min=None, max=None):
if min is not None:
mask_lower = x < min
if max is not None:
mask_upper = x > max
if min is not None:
if max is not None:
return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
return mx.where(mask_lower, min, x)
return mx.where(mask_upper, max, x)
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias, padding):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal(shape=(channels, 1, kernel_size))
scale = math.sqrt(1.0 / (channels * kernel_size))
self.weight *= scale
if bias:
self.bias = mx.zeros((channels,))
else:
self.bias = None
def __call__(self, x):
# x shape is (B, C, L)
B, C, L = x.shape
# Pad the input
if self.padding > 0:
padding = [(0, 0), (0, 0), (self.padding, self.padding)]
x_padded = mx.pad(x, padding)
else:
x_padded = x
# Perform depthwise convolution manually
out = []
for i in range(L):
slice = x_padded[:, :, i:i+self.kernel_size]
out.append(mx.sum(slice * self.weight, axis=2))
out = mx.stack(out, axis=2)
# Apply bias if present
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1)
return out
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.ssm_state_size = args.state_size
self.conv_kernel_size = args.conv_kernel
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias)
self.conv1d = DepthWiseConv1d(
channels=int(self.intermediate_size),
kernel_size=int(self.conv_kernel_size),
bias=self.use_conv_bias,
padding=int(self.conv_kernel_size - 1)
)
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
dt_init_std = args.dt_rank**-0.5 * args.state_size
if args.time_step_init_scheme == "constant":
self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
elif args.time_step_init_scheme == "random":
self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
else:
raise NotImplementedError
dt = clamp(mx.exp(
mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
), min=args.time_step_floor)
inv_dt = dt + mx.log1p(-mx.exp(-dt))
self.dt_proj.bias = inv_dt
A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0)
self.A_log = mx.log(A)
self.D = mx.ones([self.intermediate_size])
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm(self, x, h):
A = -mx.exp(self.A_log) # (ED, N)
D = self.D
deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N)
delta = nn.softplus(self.dt_proj(delta)) # (B, ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N)
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N)
BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N)
if h is None:
h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N)
h = deltaA * h + BX # (B, ED, N)
y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED)
y = y + D * x
return y, h
def __call__(self, x, cache):
h, inputs = cache
x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1)
# x is now (B, L, C), we need (B, C, L) for conv1d
x_cache = x.transpose(0, 2, 1)
if inputs is None:
inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1))
else:
inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence)
conv_input = mx.concatenate([inputs, x_cache], axis=2)
x = self.conv1d(conv_input)
x = x[:, :, -1] # Take the last element of the sequence
y, h = self.ssm(x, h)
output = y * nn.silu(z[:, -1, :])
# Update inputs for the next iteration
inputs = conv_input[:, :, 1:]
return self.out_proj(output), (h, inputs.transpose(0, 2, 1))
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, inputs: mx.array, cache):
output, cache = self.mixer(self.norm(inputs), cache)
output = output + inputs[:, -1, :] # Add residual only for the last time step
return output, cache
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, inputs: mx.array, cache):
tokens = self.embeddings(inputs)
for i, layer in enumerate(self.layers):
h, cache[i] = layer(tokens, cache[i])
h = self.norm_f(h)
return h, cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba(args)
def __call__(self, inputs: mx.array, cache=None):
out, cache = self.backbone(inputs, cache)
out = self.backbone.embeddings.as_linear(out)
return out, cache
@property
def layers(self):
return self.backbone.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_hidden_layers
@property
def n_kv_heads(self):
return self.args.num_hidden_layers
def make_cache(self):
# return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))]