diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 172aab68..86e4977f 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import math @@ -32,6 +33,8 @@ class ModelArgs(BaseModelArgs): use_cache: bool use_mambapy: bool = False dt_rank: str = "auto" + tie_word_embeddings: bool = True + def __post_init__(self): if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): @@ -40,12 +43,6 @@ class ModelArgs(BaseModelArgs): self.intermediate_size = self.d_inner if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): self.state_size = self.d_state - if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): - self.time_step_min = self.dt_min - if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): - self.time_step_min = self.dt_max - if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): - self.time_step_min = self.dt_init_floor if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): self.num_hidden_layers = self.n_layer if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): @@ -61,6 +58,56 @@ class ModelArgs(BaseModelArgs): if self.dt_rank == "auto": self.dt_rank = math.ceil(self.hidden_size / 16) +class DepthWiseConv1d(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + padding: int = 0 + ): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal((channels, 1, kernel_size)) + if bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x, cache=None): + B, L, C = x.shape + assert C == self.channels, f"Input channels ({C}) must match the initialized channels ({self.channels})." 
+ + w = self.weight # Shape: (C, 1, K) + K = self.kernel_size + total_padding = self.padding + K - 1 + + if cache is not None: + l = [] + if cache.shape[1] < total_padding: + l.append(mx.zeros((B, total_padding - cache.shape[1], C), dtype=x.dtype)) + l.extend([cache, x]) + x = mx.concatenate(l, axis=1) + else: + x = mx.pad(x, [(0, 0), (total_padding, 0), (0, 0)]) + + # Manual depthwise convolution + output = [] + for i in range(K): + slice = x[:, i:i+L, :] + output.append(slice * w[:, 0, i]) + y = mx.sum(mx.stack(output), axis=0) + + # The cache is always total_padding + cache = x[:, max(x.shape[1] - total_padding, 0):, :] + + if self.bias is not None: + y = y + self.bias.reshape(1, 1, -1) + + return y, cache + def clamp(x, min=None, max=None): if min is not None: @@ -72,54 +119,6 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) return mx.where(mask_lower, min, x) return mx.where(mask_upper, max, x) - - -class Conv1d(nn.Module): - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - padding: int = 0 - ): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.use_bias = bias - self.padding = padding - - # Change the weight initialization to match the expected shape - self.weight = mx.zeros((kernel_size, 1, channels)) - if self.use_bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None - - def __call__(self, x, cache=None): - # Use the weight directly without transposing - w = self.weight - if cache is not None: - l = [] - # Pad the cache if needed - if cache.shape[1] < self.kernel_size - 1: - l.append( - mx.zeros( - (x.shape[0], self.kernel_size - 1 - cache.shape[1], self.channels), dtype=x.dtype - ) - ) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - y = mx.conv_general(x, w, padding=([0], [0]), groups=self.channels) - else: - y = mx.conv_general(x, w, padding=([self.padding], [0]), groups=self.channels) - - # The cache is always kernel_size - 1 - cache = x[:, max(x.shape[1] - self.kernel_size + 1, 0) :, :] - - if self.use_bias: - y = y + self.bias - - return y, cache class MambaBlock(nn.Module): @@ -127,50 +126,46 @@ class MambaBlock(nn.Module): super().__init__() self.args = args - # projects block input from D to 2*ED (two branches) - self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias - # short 1d conv over time - self.conv1d = Conv1d( - channels=args.intermediate_size, - kernel_size=args.conv_kernel, - bias=args.use_conv_bias, - padding=args.conv_kernel-1 + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 ) - # projects x to input-dependent Δ, B, C - self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - # projects Δ from dt_rank to intermediate_size - self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, 
bias=True) - - # dt initialization - # dt weights - dt_init_std = args.dt_rank**-0.5 * args.state_size - + dt_init_std = args.time_step_rank**-0.5 * args.state_size if args.time_step_init_scheme == "constant": self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) elif args.time_step_init_scheme == "random": self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) else: raise NotImplementedError - - # dt bias + dt = clamp(mx.exp( mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) ), min=args.time_step_floor) inv_dt = dt + mx.log1p(-mx.exp(-dt)) self.dt_proj.bias = inv_dt - # S4D real initialization - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) - self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? - self.D = mx.ones([args.intermediate_size]) + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) - # projects block output from ED back to D - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm(self, x, h): + def ssm_step(self, x, h): # x : (B, ED) # h : (B, ED, N) @@ -182,7 +177,7 @@ class MambaBlock(nn.Module): deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) delta = nn.softplus(self.dt_proj(delta)) # (B, ED) deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) @@ -191,51 +186,55 @@ class MambaBlock(nn.Module): BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) if h is None: - h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) h = deltaA * h + BX # (B, ED, N) y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) y = y + D * x - + return y, h def __call__(self, x, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs : (B, conv_kernel-1, ED) - - # y : (B, D) - # cache : (h, inputs) - + # x : (B, T, D) where T is the number of tokens (5 in this case) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs : (B, d_conv-1, ED) + h, inputs = cache - - print("Input shape:", x.shape) - xz = self.in_proj(x) # (B, 2*ED) - xz = xz.reshape(x.shape[0], -1) # Ensure shape is (B, 2*ED) - print("After in_proj shape:", xz.shape) - x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + B, T, D = x.shape - # x branch - x_cache = mx.expand_dims(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) + outputs = [] + for t in range(T): + xt = x[:, t, :] # (B, D) + xz = self.in_proj(xt) # (B, 2*ED) + x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - x = nn.silu(x) - y, h = self.ssm_step(x, h) + # x branch + x_cache = mx.expand_dims(x_t, 1) # (B, 1, ED) + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + 
            conv_out, new_inputs = self.conv1d(conv_input)  # (B, d_conv, ED), (B, d_conv-1, ED)
+            x_t = conv_out[:, -1, :]  # (B, ED)
 
-        # z branch
-        z = nn.silu(z)
+            x_t = nn.silu(x_t)
+            y_t, h = self.ssm_step(x_t, h)
 
-        output = y * z
-        output = self.out_proj(output)  # (B, D)
+            # z branch
+            z_t = nn.silu(z_t)
 
-        # prepare cache for next call
-        inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1)  # (B, conv_kernel-1, ED)
+            output_t = y_t * z_t
+            output_t = self.out_proj(output_t)  # (B, D)
+            outputs.append(output_t)
+
+            # Update inputs for next token
+            inputs = new_inputs
+
+        output = mx.stack(outputs, axis=1)  # (B, T, D)
         cache = (h, inputs)
-        
+
         return output, cache
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -243,19 +242,12 @@ class ResidualBlock(nn.Module):
         self.mixer = MambaBlock(args)
         self.norm = nn.RMSNorm(args.hidden_size)
 
-    def __call__(self, inputs: mx.array, cache):
-        # x : (B, D)
-        # cache : (h, inputs)
-        # h : (B, ED, N)
-        # inputs: (B, conv_kernel-1, ED)
-
-        # output : (B, D)
-        # cache : (h, inputs)
-
-        output, cache = self.mixer(self.norm(inputs), cache)
-        output = output + inputs
+    def __call__(self, x: mx.array, cache):
+        output, cache = self.mixer(self.norm(x), cache)
+        output = output + x
         return output, cache
 
+
 class Mamba(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -263,22 +255,10 @@ class Mamba(nn.Module):
         self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
         self.norm_f = nn.RMSNorm(args.hidden_size)
 
-    def __call__(self, tokens: mx.array, caches):
-        # tokens : (B, L)
-
-        # logits : (B, L, vocab_size)
-
-        x = self.embeddings(tokens)
-
-        # x : (B, L, D)
-        # caches : [cache(layer) for all layers], cache : (h, inputs)
-
-        # y : (B, L, D)
-        # caches : [cache(layer) for all layers], cache : (h, inputs)
-
+    def __call__(self, x: mx.array, caches):
+        x = self.embeddings(x)
         for i, layer in enumerate(self.layers):
             x, caches[i] = layer(x, caches[i])
-
         return x, caches
 
 
@@ -289,10 +269,36 @@ class Model(nn.Module):
         self.model_type = args.model_type
         self.backbone = Mamba(args)
 
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
     def __call__(self, inputs: mx.array, cache=None):
-        out, cache = self.backbone(inputs, cache)
-        # out = self.backbone.embeddings.as_linear(out)
-        return out, cache
+        # inputs : (B, T) token ids
+        # cache : [cache(layer) for all layers], cache : (h, inputs)
+
+        if inputs.ndim == 1:
+            inputs = mx.expand_dims(inputs, 0)  # Add a batch dimension if not present
+
+        B, T = inputs.shape
+        x = self.backbone.embeddings(inputs)  # (B, T, D)
+
+        for i, layer in enumerate(self.backbone.layers):
+            x, cache[i] = layer(x, cache[i])
+
+        x = self.backbone.norm_f(x)
+
+        if self.args.tie_word_embeddings:
+            logits = self.backbone.embeddings.as_linear(x)
+        else:
+            logits = self.lm_head(x)
+
+        # logits : (B, T, vocab_size)
+        return logits, cache
+
+    def make_cache(self):
+        B = 1  # cache is created for batch size 1; rebuild it for larger batches
+        return [(None, mx.zeros((B, self.args.conv_kernel-1, self.args.intermediate_size)))
+                for _ in range(self.args.num_hidden_layers)]
 
     @property
     def layers(self):
@@ -306,17 +312,5 @@ class Model(nn.Module):
     def n_kv_heads(self):
         return self.args.num_hidden_layers
 
-    def make_cache(self):
-        return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)]
-
-    def sanitize(self, weights):
- for key, value in weights.items(): - if "mixer.conv1d.weight" in key: - # Ensure the weight is in the shape (kernel_size, 1, channels) - if value.shape != (self.args.conv_kernel, 1, self.args.intermediate_size): - weights[key] = value.reshape(self.args.conv_kernel, 1, self.args.intermediate_size) - elif key == "backbone.embeddings.weight": - # Ensure the embedding weight is in the shape (vocab_size, hidden_size) - if value.shape != (self.args.vocab_size, self.args.hidden_size): - weights[key] = value.T - return weights \ No newline at end of file + # def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba1.py b/llms/mlx_lm/models/mamba1.py deleted file mode 100644 index 0b64f967..00000000 --- a/llms/mlx_lm/models/mamba1.py +++ /dev/null @@ -1,293 +0,0 @@ -from dataclasses import dataclass - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): - self.time_step_min = self.dt_min - if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): - self.time_step_min = self.dt_max - if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): - self.time_step_min = self.dt_init_floor - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - - self.conv1d = nn.Conv1d( - in_channels=channels, - 
out_channels=channels, - kernel_size=kernel_size, - bias=True, - padding=padding - ) - indices = mx.arange(channels) - mask = mx.zeros_like(self.conv1d.weight) - mask[indices, :, indices] = 1 - self.conv1d.weight *= mask - - def __call__(self, x): - return self.conv1d(x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - # projects block input from D to 2*ED (two branches) - self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) - - # short 1d conv over time - self.conv1d = DepthWiseConv1d( - channels=args.intermediate_size, - kernel_size=args.conv_kernel, - bias=args.use_conv_bias, - padding=args.conv_kernel-1 - ) - - # projects x to input-dependent Δ, B, C - self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) - - # projects Δ from dt_rank to intermediate_size - self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True) - - # dt initialization - # dt weights - dt_init_std = args.dt_rank**-0.5 * args.state_size - - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - # dt bias - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - self.dt_proj.bias = inv_dt - - # S4D real initialization - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) - self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? 
- self.D = mx.ones([args.intermediate_size]) - - # projects block output from ED back to D - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - # x : (B, ED) - # h : (B, ED, N) - - # y : (B, ED) - # h : (B, ED, N) - - A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - - return y, h - - def __call__(self, x, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs : (B, conv_kernel-1, ED) - - # y : (B, D) - # cache : (h, inputs) - - h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - - # x branch - x_cache = mx.expand_dims(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm_step(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) - cache = (h, inputs) - - return output, cache - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs: (B, conv_kernel-1, ED) - - # output : (B, D) - # cache : (h, inputs) - - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, tokens: mx.array, caches): - # tokens : (B, L) - - # logits : (B, L, vocab_size) - - x = self.embeddings(tokens) - - # x : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - - # y : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - - for i, layer in enumerate(self.layers): - x, caches[i] = layer(x, caches[i]) - - return x, caches - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - # out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return 
self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def sanitize(self, weights): - new_weights = {} - for key, value in weights.items(): - if "mixer.conv1d.weight" in key: - weights[key] = value.T - new_key = key.replace('mixer.conv1d', 'mixer.conv1d.conv1d') - new_weights[new_key] = value - return new_weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py deleted file mode 100644 index 04f67d05..00000000 --- a/llms/mlx_lm/models/mamba2.py +++ /dev/null @@ -1,258 +0,0 @@ -from dataclasses import dataclass - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) - self.weight *= scale - if bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None - - def __call__(self, x): - # x shape is (B, C, L) - B, C, L = x.shape - - # Pad the input - if self.padding > 0: - padding = [(0, 0), (0, 0), (self.padding, self.padding)] - x_padded = mx.pad(x, padding) - else: - x_padded = x - - # Perform depthwise convolution manually - out = 
[] - for i in range(L): - slice = x_padded[:, :, i:i+self.kernel_size] - out.append(mx.sum(slice * self.weight, axis=2)) - - out = mx.stack(out, axis=2) - - # Apply bias if present - if self.bias is not None: - out = out + self.bias.reshape(1, -1, 1) - - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=int(self.intermediate_size), - kernel_size=int(self.conv_kernel_size), - bias=self.use_conv_bias, - padding=int(self.conv_kernel_size - 1) - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED) - - y = y + D * x - return y, h - - def __call__(self, x, cache): - h, inputs = cache - - x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) - - # x is now (B, L, C), we need (B, C, L) for conv1d - x_cache = x.transpose(0, 2, 1) - - if inputs is None: - inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1)) - else: - inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence) - - conv_input = mx.concatenate([inputs, x_cache], axis=2) - - x = self.conv1d(conv_input) - x = x[:, :, -1] # Take the last element of the sequence - - y, h = self.ssm(x, h) - output = y * nn.silu(z[:, -1, :]) - - # Update inputs for the next iteration - inputs = conv_input[:, :, 1:] - - return self.out_proj(output), (h, inputs.transpose(0, 2, 1)) - 
-class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs[:, -1, :] # Add residual only for the last time step - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] \ No newline at end of file
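
Not part of the diff above: a minimal driver sketch of how the reworked mamba.py is expected to be used once this change lands. It assumes the module is importable as mlx_lm.models.mamba and uses made-up ModelArgs values in place of a real config.json; the make_cache() call and the token-by-token decode loop mirror the new Model.__call__ and Model.make_cache, while the vocabulary size, layer count, and greedy argmax sampling are purely illustrative.

# usage_sketch.py -- illustrative only; import path and config values are assumptions.
import mlx.core as mx

from mlx_lm.models.mamba import Model, ModelArgs

# Tiny stand-in configuration; real values come from the checkpoint's config.json.
args = ModelArgs(
    model_type="mamba",
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,   # recomputed as expand * hidden_size in __post_init__
    state_size=16,
    num_hidden_layers=2,
    layer_norm_epsilon=1e-5,
    expand=2,
    conv_kernel=4,
    use_bias=False,
    use_conv_bias=True,
    initializer_range=0.1,
    time_step_rank=4,
    time_step_scale=1.0,
    time_step_min=0.001,
    time_step_max=0.1,
    time_step_init_scheme="random",
    time_step_floor=1e-4,
    rescale_prenorm_residual=False,
    use_cache=True,
)

model = Model(args)
cache = model.make_cache()            # one (h, conv_inputs) entry per layer

prompt = mx.array([[12, 345, 678]])   # (B, T) token ids
logits, cache = model(prompt, cache)  # logits : (B, T, vocab_size)

# Greedy decoding, one token at a time, reusing the recurrent cache.
token = mx.argmax(logits[:, -1, :], axis=-1)
for _ in range(8):
    logits, cache = model(mx.expand_dims(token, 0), cache)
    token = mx.argmax(logits[:, -1, :], axis=-1)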