mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Fixing the batched depthwise convolution and multi-token input
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional

import math

@@ -32,6 +33,8 @@ class ModelArgs(BaseModelArgs):
    use_cache: bool
    use_mambapy: bool = False
    dt_rank: str = "auto"
    tie_word_embeddings: bool = True


    def __post_init__(self):
        if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
@@ -40,12 +43,6 @@ class ModelArgs(BaseModelArgs):
            self.intermediate_size = self.d_inner
        if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
            self.state_size = self.d_state
        if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'):
            self.time_step_min = self.dt_min
        if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'):
            self.time_step_max = self.dt_max
        if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'):
            self.time_step_floor = self.dt_init_floor
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
            self.num_hidden_layers = self.n_layer
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
@@ -61,6 +58,56 @@ class ModelArgs(BaseModelArgs):
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.hidden_size / 16)

class DepthWiseConv1d(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        bias: bool = True,
        padding: int = 0
    ):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((channels, 1, kernel_size))
        if bias:
            self.bias = mx.zeros((channels,))
        else:
            self.bias = None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        assert C == self.channels, f"Input channels ({C}) must match the initialized channels ({self.channels})."

        w = self.weight  # Shape: (C, 1, K)
        K = self.kernel_size
        total_padding = self.padding + K - 1

        if cache is not None:
            l = []
            if cache.shape[1] < total_padding:
                l.append(mx.zeros((B, total_padding - cache.shape[1], C), dtype=x.dtype))
            l.extend([cache, x])
            x = mx.concatenate(l, axis=1)
        else:
            x = mx.pad(x, [(0, 0), (total_padding, 0), (0, 0)])

        # Manual depthwise convolution
        output = []
        for i in range(K):
            slice = x[:, i:i+L, :]
            output.append(slice * w[:, 0, i])
        y = mx.sum(mx.stack(output), axis=0)

        # The cache always keeps the last total_padding positions
        cache = x[:, max(x.shape[1] - total_padding, 0):, :]

        if self.bias is not None:
            y = y + self.bias.reshape(1, 1, -1)

        return y, cache

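For reference, the cached path above is intended to reproduce the full-sequence path token by token: the returned cache always holds the last total_padding positions of the (padded) input, so feeding one token at a time should be equivalent to one big left-padded convolution. A minimal sanity sketch with hypothetical shapes, assuming mlx is installed and the DepthWiseConv1d class above is in scope:

import mlx.core as mx

B, T, C, K = 2, 6, 4, 3
conv = DepthWiseConv1d(channels=C, kernel_size=K, bias=True, padding=K - 1)
x = mx.random.normal((B, T, C))

# Full sequence in one call; cache=None triggers the left-padding branch.
y_full, _ = conv(x)

# Same sequence one token at a time, threading the returned cache through.
cache = None
ys = []
for t in range(T):
    y_t, cache = conv(x[:, t:t + 1, :], cache=cache)
    ys.append(y_t)
y_step = mx.concatenate(ys, axis=1)

print(mx.allclose(y_full, y_step))  # should print True if the cache logic is consistent
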
def clamp(x, min=None, max=None):
    if min is not None:
@@ -72,54 +119,6 @@ def clamp(x, min=None, max=None):
            return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
        return mx.where(mask_lower, min, x)
    return mx.where(mask_upper, max, x)

class Conv1d(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        bias: bool = True,
        padding: int = 0
    ):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.use_bias = bias
        self.padding = padding

        # Change the weight initialization to match the expected shape
        self.weight = mx.zeros((kernel_size, 1, channels))
        if self.use_bias:
            self.bias = mx.zeros((channels,))
        else:
            self.bias = None

    def __call__(self, x, cache=None):
        # Use the weight directly without transposing
        w = self.weight
        if cache is not None:
            l = []
            # Pad the cache if needed
            if cache.shape[1] < self.kernel_size - 1:
                l.append(
                    mx.zeros(
                        (x.shape[0], self.kernel_size - 1 - cache.shape[1], self.channels), dtype=x.dtype
                    )
                )
            l.extend([cache, x])
            x = mx.concatenate(l, axis=1)
            y = mx.conv_general(x, w, padding=([0], [0]), groups=self.channels)
        else:
            y = mx.conv_general(x, w, padding=([self.padding], [0]), groups=self.channels)

        # The cache is always kernel_size - 1
        cache = x[:, max(x.shape[1] - self.kernel_size + 1, 0):, :]

        if self.use_bias:
            y = y + self.bias

        return y, cache

class MambaBlock(nn.Module):
@@ -127,50 +126,46 @@ class MambaBlock(nn.Module):
        super().__init__()
        self.args = args

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias)
        self.hidden_size = args.hidden_size
        self.ssm_state_size = args.state_size
        self.conv_kernel_size = args.conv_kernel
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias

        # short 1d conv over time
        self.conv1d = Conv1d(
            channels=args.intermediate_size,
            kernel_size=args.conv_kernel,
            bias=args.use_conv_bias,
            padding=args.conv_kernel-1
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias)

        self.conv1d = DepthWiseConv1d(
            channels=self.intermediate_size,
            kernel_size=self.conv_kernel_size,
            bias=self.use_conv_bias,
            padding=self.conv_kernel_size-1
        )

        # projects x to input-dependent Δ, B, C
        self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False)
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # projects Δ from dt_rank to intermediate_size
        self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True)

        # dt initialization
        # dt weights
        dt_init_std = args.dt_rank**-0.5 * args.state_size

        dt_init_std = args.time_step_rank**-0.5 * args.state_size
        if args.time_step_init_scheme == "constant":
            self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
        elif args.time_step_init_scheme == "random":
            self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
        else:
            raise NotImplementedError

        # dt bias

        dt = clamp(mx.exp(
            mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
        ), min=args.time_step_floor)
        inv_dt = dt + mx.log1p(-mx.exp(-dt))
        self.dt_proj.bias = inv_dt

        # S4D real initialization
        A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0)
        self.A_log = mx.log(A)  # why store A in log? to keep A < 0 (cf. -torch.exp(...))? for gradient stability?
        self.D = mx.ones([args.intermediate_size])
        A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0)
        self.A_log = mx.log(A)
        self.D = mx.ones([self.intermediate_size])

        # projects block output from ED back to D
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)

    def ssm(self, x, h):
    def ssm_step(self, x, h):
        # x : (B, ED)
        # h : (B, ED, N)

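Note the dt bias initialization above: inv_dt = dt + log1p(-exp(-dt)) is the inverse of softplus, so softplus(dt_proj.bias) recovers the sampled dt on the first forward pass. A quick check with example values:

import mlx.core as mx
import mlx.nn as nn

dt = mx.array([0.01, 0.1, 1.0])         # example timestep values
inv_dt = dt + mx.log1p(-mx.exp(-dt))    # inverse of softplus
print(nn.softplus(inv_dt))              # ~[0.01, 0.1, 1.0]
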
@@ -182,7 +177,7 @@ class MambaBlock(nn.Module):

        deltaBC = self.x_proj(x)  # (B, dt_rank+2*N)

        delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1)  # (B, dt_rank), (B, N), (B, N)
        delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1)  # (B, dt_rank), (B, N), (B, N)
        delta = nn.softplus(self.dt_proj(delta))  # (B, ED)

        deltaA = mx.exp(mx.expand_dims(delta, -1) * A)  # (B, ED, N)
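The split above carves the x_proj output into Δ, B, and C at offsets dt_rank and dt_rank + N along the last axis. A toy illustration with hypothetical sizes:

import mlx.core as mx

dt_rank, N = 3, 4
deltaBC = mx.random.normal((2, dt_rank + 2 * N))        # (B, dt_rank + 2N)
delta, B, C = mx.split(deltaBC, indices_or_sections=[dt_rank, dt_rank + N], axis=-1)
print(delta.shape, B.shape, C.shape)                    # (2, 3) (2, 4) (2, 4)
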
@@ -191,51 +186,55 @@ class MambaBlock(nn.Module):
        BX = deltaB * mx.expand_dims(x, -1)  # (B, ED, N)

        if h is None:
            h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size])  # (B, ED, N)
            h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size])  # (B, ED, N)

        h = deltaA * h + BX  # (B, ED, N)

        y = (h @ mx.expand_dims(C, -1)).squeeze(2)  # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        return y, h

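ssm_step is one step of the discretized selective state-space recurrence: h <- exp(Δ·A) * h + (Δ·B) * x, then y = C·h + D·x. A standalone sketch with toy dimensions (hypothetical values, independent of the classes above):

import mlx.core as mx

Bsz, ED, N = 1, 2, 4
A = -mx.repeat(mx.arange(1.0, N + 1.0).reshape(1, N), repeats=ED, axis=0)  # (ED, N), negative for stability
h = mx.zeros((Bsz, ED, N))                   # SSM hidden state
x = mx.random.normal((Bsz, ED))              # one token after the conv branch
delta = mx.abs(mx.random.normal((Bsz, ED)))  # positive timestep
Bm = mx.random.normal((Bsz, N))
C = mx.random.normal((Bsz, N))
D = mx.ones((ED,))

deltaA = mx.exp(mx.expand_dims(delta, -1) * A)                                  # (B, ED, N)
BX = mx.expand_dims(delta, -1) * mx.expand_dims(Bm, 1) * mx.expand_dims(x, -1)  # (B, ED, N)
h = deltaA * h + BX                                                             # state update
y = (h @ mx.expand_dims(C, -1)).squeeze(2) + D * x                              # (B, ED)
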
    def __call__(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs : (B, conv_kernel-1, ED)

        # y : (B, D)
        # cache : (h, inputs)

        # x : (B, T, D) where T is the number of tokens
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs : (B, d_conv-1, ED)

        h, inputs = cache

        print("Input shape:", x.shape)
        xz = self.in_proj(x)  # (B, 2*ED)
        xz = xz.reshape(x.shape[0], -1)  # Ensure shape is (B, 2*ED)
        print("After in_proj shape:", xz.shape)
        x, z = xz.split(indices_or_sections=2, axis=1)  # (B, ED), (B, ED)
        B, T, D = x.shape

        # x branch
        x_cache = mx.expand_dims(x, 1)
        x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :]  # (B, ED)
        outputs = []
        for t in range(T):
            xt = x[:, t, :]  # (B, D)
            xz = self.in_proj(xt)  # (B, 2*ED)
            x_t, z_t = xz.split(indices_or_sections=2, axis=1)  # (B, ED), (B, ED)

        x = nn.silu(x)
        y, h = self.ssm_step(x, h)
            # x branch
            x_cache = mx.expand_dims(x_t, 1)  # (B, 1, ED)
            conv_input = mx.concatenate([inputs, x_cache], axis=1)  # (B, d_conv, ED)
            conv_out, new_inputs = self.conv1d(conv_input)  # (B, d_conv, ED), (B, d_conv-1, ED)
            x_t = conv_out[:, -1, :]  # (B, ED)

        # z branch
        z = nn.silu(z)
            x_t = nn.silu(x_t)
            y_t, h = self.ssm_step(x_t, h)

        output = y * z
        output = self.out_proj(output)  # (B, D)
            # z branch
            z_t = nn.silu(z_t)

        # prepare cache for next call
        inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1)  # (B, conv_kernel-1, ED)
            output_t = y_t * z_t
            output_t = self.out_proj(output_t)  # (B, D)
            outputs.append(output_t)

            # Update inputs for next token
            inputs = new_inputs

        output = mx.stack(outputs, axis=1)  # (B, T, D)
        cache = (h, inputs)

        return output, cache

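Taken together, the reworked __call__ handles multi-token input by looping over the T tokens, pushing each through the convolution cache and the SSM state in turn. A hypothetical smoke test, assuming the post-commit versions of MambaBlock and DepthWiseConv1d (i.e., the replacement lines above) are what ends up in the module, with a minimal stand-in for ModelArgs:

from types import SimpleNamespace
import mlx.core as mx

args = SimpleNamespace(
    hidden_size=16, intermediate_size=32, state_size=4, conv_kernel=4,
    time_step_rank=2, use_bias=False, use_conv_bias=True,
    time_step_init_scheme="random", time_step_min=0.001,
    time_step_max=0.1, time_step_floor=1e-4,
)
block = MambaBlock(args)
x = mx.random.normal((1, 5, args.hidden_size))  # B=1, T=5 tokens
cache = (None, mx.zeros((1, args.conv_kernel - 1, args.intermediate_size)))
y, cache = block(x, cache)
print(y.shape)  # (1, 5, 16)
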
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
@@ -243,19 +242,12 @@ class ResidualBlock(nn.Module):
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, inputs: mx.array, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs: (B, conv_kernel-1, ED)

        # output : (B, D)
        # cache : (h, inputs)

        output, cache = self.mixer(self.norm(inputs), cache)
        output = output + inputs
    def __call__(self, x: mx.array, cache):
        output, cache = self.mixer(self.norm(x), cache)
        output = output + x
        return output, cache

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
@@ -263,22 +255,11 @@ class Mamba(nn.Module):
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, tokens: mx.array, caches):
        # tokens : (B, L)

        # logits : (B, L, vocab_size)

        x = self.embeddings(tokens)

        # x : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        # y : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

    def __call__(self, x: mx.array, caches):
        x = self.embeddings(x)
        print(x.shape)
        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, caches[i])

        return x, caches

@@ -289,10 +270,39 @@ class Model(nn.Module):
        self.model_type = args.model_type
        self.backbone = Mamba(args)

        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        out, cache = self.backbone(inputs, cache)
        # out = self.backbone.embeddings.as_linear(out)
        return out, cache
        # inputs : (B, T) where T is the number of tokens
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        if inputs.ndim == 1:
            inputs = mx.expand_dims(inputs, 0)  # Add batch dimension if not present

        B, T = inputs.shape
        x = self.backbone.embeddings(inputs)  # (B, T, D)

        for i, layer in enumerate(self.backbone.layers):
            x, cache[i] = layer(x, cache[i])

        x = self.backbone.norm_f(x)

        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(x)
        else:
            logits = self.lm_head(x)

        print(f"Logits shape: {logits.shape}")
        # logits : (B, T, vocab_size)
        print(logits)

        return logits, cache

    def make_cache(self):
        B = 1  # Assuming batch size of 1 for simplicity
        return [(None, mx.zeros((B, self.args.conv_kernel-1, self.args.intermediate_size)))
                for _ in range(self.args.num_hidden_layers)]

    @property
    def layers(self):
@@ -306,17 +316,5 @@ class Model(nn.Module):
    def n_kv_heads(self):
        return self.args.num_hidden_layers

    def make_cache(self):
        return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]

    def sanitize(self, weights):
        for key, value in weights.items():
            if "mixer.conv1d.weight" in key:
                # Ensure the weight is in the shape (kernel_size, 1, channels)
                if value.shape != (self.args.conv_kernel, 1, self.args.intermediate_size):
                    weights[key] = value.reshape(self.args.conv_kernel, 1, self.args.intermediate_size)
            elif key == "backbone.embeddings.weight":
                # Ensure the embedding weight is in the shape (vocab_size, hidden_size)
                if value.shape != (self.args.vocab_size, self.args.hidden_size):
                    weights[key] = value.T
        return weights
    # def make_cache(self):
    #     return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
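At the model level, the new __call__ plus make_cache support both prompt processing and step-by-step decoding. A hypothetical driver loop, assuming model is a loaded post-commit Model instance:

import mlx.core as mx

prompt = mx.array([[12, 7, 99, 3, 42]])  # (B=1, T=5) example token ids
cache = model.make_cache()               # one (h, conv_inputs) pair per layer
logits, cache = model(prompt, cache)     # logits : (1, 5, vocab_size)

# Greedy decoding one token at a time, reusing the per-layer caches.
for _ in range(10):
    next_token = mx.argmax(logits[:, -1, :], axis=-1)  # shape (1,)
    logits, cache = model(next_token, cache)           # 1-D input gets a batch dim added
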
@@ -1,293 +0,0 @@
from dataclasses import dataclass

import math

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    time_step_rank: int
    time_step_scale: float
    time_step_min: float
    time_step_max: float
    time_step_init_scheme: str
    time_step_floor: float
    rescale_prenorm_residual: bool
    use_cache: bool
    use_mambapy: bool = False
    dt_rank: str = "auto"

    def __post_init__(self):
        if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
            self.hidden_size = self.d_model
        if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'):
            self.intermediate_size = self.d_inner
        if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
            self.state_size = self.d_state
        if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'):
            self.time_step_min = self.dt_min
        if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'):
            self.time_step_min = self.dt_max
        if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'):
            self.time_step_min = self.dt_init_floor
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
            self.num_hidden_layers = self.n_layer
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
            self.num_hidden_layers = self.n_layers
        if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'):
            self.conv_kernel = self.d_conv
        if not hasattr(self, 'use_bias') and hasattr(self, 'bias'):
            self.use_bias = self.bias
        if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'):
            self.use_conv_bias = self.conv_bias

        self.intermediate_size = self.expand * self.hidden_size
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.hidden_size / 16)


def clamp(x, min=None, max=None):
    if min is not None:
        mask_lower = x < min
    if max is not None:
        mask_upper = x > max
    if min is not None:
        if max is not None:
            return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
        return mx.where(mask_lower, min, x)
    return mx.where(mask_upper, max, x)


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias, padding):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.bias = bias
        self.padding = padding

        self.conv1d = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            bias=True,
            padding=padding
        )
        indices = mx.arange(channels)
        mask = mx.zeros_like(self.conv1d.weight)
        mask[indices, :, indices] = 1
        self.conv1d.weight *= mask

    def __call__(self, x):
        return self.conv1d(x)


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias)

        # short 1d conv over time
        self.conv1d = DepthWiseConv1d(
            channels=args.intermediate_size,
            kernel_size=args.conv_kernel,
            bias=args.use_conv_bias,
            padding=args.conv_kernel-1
        )

        # projects x to input-dependent Δ, B, C
        self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False)

        # projects Δ from dt_rank to intermediate_size
        self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True)

        # dt initialization
        # dt weights
        dt_init_std = args.dt_rank**-0.5 * args.state_size

        if args.time_step_init_scheme == "constant":
            self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
        elif args.time_step_init_scheme == "random":
            self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
        else:
            raise NotImplementedError

        # dt bias
        dt = clamp(mx.exp(
            mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
        ), min=args.time_step_floor)
        inv_dt = dt + mx.log1p(-mx.exp(-dt))  # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        self.dt_proj.bias = inv_dt

        # S4D real initialization
        A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0)
        self.A_log = mx.log(A)  # why store A in log? to keep A < 0 (cf. -torch.exp(...))? for gradient stability?
        self.D = mx.ones([args.intermediate_size])

        # projects block output from ED back to D
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

    def ssm(self, x, h):
        # x : (B, ED)
        # h : (B, ED, N)

        # y : (B, ED)
        # h : (B, ED, N)

        A = -mx.exp(self.A_log)  # (ED, N)  # todo : move out of step (timestep independent)
        D = self.D

        deltaBC = self.x_proj(x)  # (B, dt_rank+2*N)

        delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1)  # (B, dt_rank), (B, N), (B, N)
        delta = nn.softplus(self.dt_proj(delta))  # (B, ED)

        deltaA = mx.exp(mx.expand_dims(delta, -1) * A)  # (B, ED, N)
        deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1)  # (B, ED, N)

        BX = deltaB * mx.expand_dims(x, -1)  # (B, ED, N)

        if h is None:
            h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size])  # (B, ED, N)

        h = deltaA * h + BX  # (B, ED, N)

        y = (h @ mx.expand_dims(C, -1)).squeeze(2)  # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        return y, h

    def __call__(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs : (B, conv_kernel-1, ED)

        # y : (B, D)
        # cache : (h, inputs)

        h, inputs = cache

        xz = self.in_proj(x)  # (B, 2*ED)
        x, z = xz.split(indices_or_sections=2, axis=1)  # (B, ED), (B, ED)

        # x branch
        x_cache = mx.expand_dims(x, 1)
        x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :]  # (B, ED)

        x = nn.silu(x)
        y, h = self.ssm_step(x, h)

        # z branch
        z = nn.silu(z)

        output = y * z
        output = self.out_proj(output)  # (B, D)

        # prepare cache for next call
        inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1)  # (B, conv_kernel-1, ED)
        cache = (h, inputs)

        return output, cache

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, inputs: mx.array, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs: (B, conv_kernel-1, ED)

        # output : (B, D)
        # cache : (h, inputs)

        output, cache = self.mixer(self.norm(inputs), cache)
        output = output + inputs
        return output, cache

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, tokens: mx.array, caches):
        # tokens : (B, L)

        # logits : (B, L, vocab_size)

        x = self.embeddings(tokens)

        # x : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        # y : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, caches[i])

        return x, caches


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)

    def __call__(self, inputs: mx.array, cache=None):
        out, cache = self.backbone(inputs, cache)
        # out = self.backbone.embeddings.as_linear(out)
        return out, cache

    @property
    def layers(self):
        return self.backbone.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_hidden_layers

    @property
    def n_kv_heads(self):
        return self.args.num_hidden_layers

    def make_cache(self):
        # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
        return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))]

    def sanitize(self, weights):
        new_weights = {}
        for key, value in weights.items():
            if "mixer.conv1d.weight" in key:
                weights[key] = value.T
                new_key = key.replace('mixer.conv1d', 'mixer.conv1d.conv1d')
                new_weights[new_key] = value
        return new_weights
@@ -1,258 +0,0 @@
from dataclasses import dataclass

import math

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    time_step_rank: int
    time_step_scale: float
    time_step_min: float
    time_step_max: float
    time_step_init_scheme: str
    time_step_floor: float
    rescale_prenorm_residual: bool
    use_cache: bool
    use_mambapy: bool = False
    dt_rank: str = "auto"

    def __post_init__(self):
        if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'):
            self.hidden_size = self.d_model
        if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'):
            self.intermediate_size = self.d_inner
        if not hasattr(self, 'state_size') and hasattr(self, 'd_state'):
            self.state_size = self.d_state
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'):
            self.num_hidden_layers = self.n_layer
        if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'):
            self.num_hidden_layers = self.n_layers
        if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'):
            self.conv_kernel = self.d_conv
        if not hasattr(self, 'use_bias') and hasattr(self, 'bias'):
            self.use_bias = self.bias
        if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'):
            self.use_conv_bias = self.conv_bias

        self.intermediate_size = self.expand * self.hidden_size
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.hidden_size / 16)


def clamp(x, min=None, max=None):
    if min is not None:
        mask_lower = x < min
    if max is not None:
        mask_upper = x > max
    if min is not None:
        if max is not None:
            return mx.where(mask_upper, max, mx.where(mask_lower, min, x))
        return mx.where(mask_lower, min, x)
    return mx.where(mask_upper, max, x)


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias, padding):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal(shape=(channels, 1, kernel_size))
        scale = math.sqrt(1.0 / (channels * kernel_size))
        self.weight *= scale
        if bias:
            self.bias = mx.zeros((channels,))
        else:
            self.bias = None

    def __call__(self, x):
        # x shape is (B, C, L)
        B, C, L = x.shape

        # Pad the input
        if self.padding > 0:
            padding = [(0, 0), (0, 0), (self.padding, self.padding)]
            x_padded = mx.pad(x, padding)
        else:
            x_padded = x

        # Perform depthwise convolution manually
        out = []
        for i in range(L):
            slice = x_padded[:, :, i:i+self.kernel_size]
            out.append(mx.sum(slice * self.weight, axis=2))

        out = mx.stack(out, axis=2)

        # Apply bias if present
        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.ssm_state_size = args.state_size
        self.conv_kernel_size = args.conv_kernel
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias

        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias)

        self.conv1d = DepthWiseConv1d(
            channels=int(self.intermediate_size),
            kernel_size=int(self.conv_kernel_size),
            bias=self.use_conv_bias,
            padding=int(self.conv_kernel_size - 1)
        )

        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        dt_init_std = args.dt_rank**-0.5 * args.state_size
        if args.time_step_init_scheme == "constant":
            self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight)
        elif args.time_step_init_scheme == "random":
            self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape)
        else:
            raise NotImplementedError

        dt = clamp(mx.exp(
            mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)
        ), min=args.time_step_floor)
        inv_dt = dt + mx.log1p(-mx.exp(-dt))
        self.dt_proj.bias = inv_dt

        A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0)
        self.A_log = mx.log(A)
        self.D = mx.ones([self.intermediate_size])

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)

    def ssm(self, x, h):
        A = -mx.exp(self.A_log)  # (ED, N)
        D = self.D

        deltaBC = self.x_proj(x)  # (B, dt_rank+2*N)

        delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1)  # (B, dt_rank), (B, N), (B, N)
        delta = nn.softplus(self.dt_proj(delta))  # (B, ED)

        deltaA = mx.exp(mx.expand_dims(delta, -1) * A)  # (B, ED, N)
        deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1)  # (B, ED, N)

        BX = deltaB * mx.expand_dims(x, -1)  # (B, ED, N)

        if h is None:
            h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size])  # (B, ED, N)

        h = deltaA * h + BX  # (B, ED, N)

        y = mx.sum(h * mx.expand_dims(C, 1), axis=-1)  # (B, ED)

        y = y + D * x
        return y, h

    def __call__(self, x, cache):
        h, inputs = cache

        x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1)

        # x is now (B, L, C), we need (B, C, L) for conv1d
        x_cache = x.transpose(0, 2, 1)

        if inputs is None:
            inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1))
        else:
            inputs = inputs.transpose(0, 2, 1)  # Change to (batch, channels, sequence)

        conv_input = mx.concatenate([inputs, x_cache], axis=2)

        x = self.conv1d(conv_input)
        x = x[:, :, -1]  # Take the last element of the sequence

        y, h = self.ssm(x, h)
        output = y * nn.silu(z[:, -1, :])

        # Update inputs for the next iteration
        inputs = conv_input[:, :, 1:]

        return self.out_proj(output), (h, inputs.transpose(0, 2, 1))

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, inputs: mx.array, cache):
        output, cache = self.mixer(self.norm(inputs), cache)
        output = output + inputs[:, -1, :]  # Add residual only for the last time step
        return output, cache


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, inputs: mx.array, cache):
        tokens = self.embeddings(inputs)
        for i, layer in enumerate(self.layers):
            h, cache[i] = layer(tokens, cache[i])
        h = self.norm_f(h)
        return h, cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)

    def __call__(self, inputs: mx.array, cache=None):
        out, cache = self.backbone(inputs, cache)
        out = self.backbone.embeddings.as_linear(out)
        return out, cache

    @property
    def layers(self):
        return self.backbone.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_hidden_layers

    @property
    def n_kv_heads(self):
        return self.args.num_hidden_layers

    def make_cache(self):
        # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
        return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))]