mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

* fix rotating kv cache for chat use case
* reorg + fixes to caching, unify prompt caching across types and use cases, e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
218 lines
6.9 KiB
Python
218 lines
6.9 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .base import BaseModelArgs
|
|
from .cache import MambaCache
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values read by the Mamba modules defined in this file.

    ``time_step_rank`` may be given as the string ``"auto"``; it is then
    replaced by ``ceil(hidden_size / 16)`` in ``__post_init__``.
    """

    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    time_step_rank: int
    tie_word_embeddings: bool = True

    def __post_init__(self):
        """Fill canonical fields from alternate names and resolve "auto" rank."""
        # (canonical field, alternate attribute) pairs, checked in order.
        # NOTE(review): every canonical name is a required dataclass field, so
        # it looks like these fallbacks can never trigger once __init__ has
        # run -- confirm against how BaseModelArgs builds instances.
        fallbacks = (
            ("hidden_size", "d_model"),
            ("intermediate_size", "d_inner"),
            ("state_size", "d_state"),
            ("num_hidden_layers", "n_layer"),
            ("num_hidden_layers", "n_layers"),
            ("conv_kernel", "d_conv"),
            ("use_bias", "bias"),
            ("use_conv_bias", "conv_bias"),
        )
        for canonical, alternate in fallbacks:
            if hasattr(self, canonical):
                continue
            if hasattr(self, alternate):
                setattr(self, canonical, getattr(self, alternate))

        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)
|
|
|
|
|
|
class DepthWiseConv1d(nn.Module):
    """Causal depth-wise 1-D convolution that also returns its rolling context.

    The weight has shape ``(channels, kernel_size, 1)`` and the convolution is
    run with ``groups`` equal to the weight's first dimension.
    """

    def __init__(self, channels, kernel_size, bias=True, padding=0):
        """Create the convolution parameters.

        Args:
            channels: Number of channels (first dimension of the weight).
            kernel_size: Width of the convolution kernel.
            bias: Whether to add a per-channel bias (initialised to zeros).
            padding: Stored on the instance only; ``__call__`` always
                left-pads by ``kernel_size - 1`` when no cache is given.
        """
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((self.channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        """Convolve ``x`` along axis 1 and return ``(output, new_cache)``.

        Args:
            x: Rank-3 input; axis 1 is the axis the kernel slides over
                (presumably ``(batch, length, channels)`` -- confirm with
                callers).
            cache: The trailing ``kernel_size - 1`` steps returned by the
                previous call, or ``None`` to start from zero padding.

        Returns:
            A tuple ``(y, new_cache)`` where ``new_cache`` holds the last
            ``kernel_size - 1`` steps of the (padded or cache-extended) input.
        """
        groups, K, _ = self.weight.shape
        context = K - 1  # number of past steps the kernel needs

        if cache is not None and context > 0:
            x = mx.concatenate([cache, x], axis=1)
        else:
            # With no cache this supplies the zero left-context. With a
            # size-1 kernel the pad width is 0, so x is left unchanged and
            # no stale context is prepended.
            x = mx.pad(x, [(0, 0), (context, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=groups)

        if self.bias is not None:
            y = y + self.bias

        # Use an explicit non-negative start: `x[:, -context:]` would select
        # the whole array when context == 0 instead of an empty context.
        start = max(x.shape[1] - context, 0)
        return y, x[:, start:, :]
|
|
|
|
|
|
class MambaBlock(nn.Module):
    """Mixer block: input projection, depth-wise conv, SSM step, output projection.

    The sequence is processed one step at a time along axis 1, carrying the
    convolution context and the SSM state in a two-slot cache.
    """

    def __init__(self, args: ModelArgs):
        """Build the projections, convolution and SSM parameters from ``args``."""
        super().__init__()
        self.args = args

        d_model = args.hidden_size
        d_inner = args.intermediate_size
        n_state = args.state_size
        kernel = args.conv_kernel
        rank = int(args.time_step_rank)

        self.hidden_size = d_model
        self.ssm_state_size = n_state
        self.conv_kernel_size = kernel
        self.intermediate_size = d_inner
        self.time_step_rank = rank
        self.use_conv_bias = args.use_conv_bias

        # Produces both halves that __call__ splits apart (conv/SSM path and gate).
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=args.use_bias)

        self.conv1d = DepthWiseConv1d(
            channels=d_inner,
            kernel_size=kernel,
            bias=self.use_conv_bias,
            padding=kernel - 1,
        )

        # Output is split in ssm_step into delta (rank), B (n_state), C (n_state).
        self.x_proj = nn.Linear(d_inner, rank + 2 * n_state, bias=False)
        self.dt_proj = nn.Linear(rank, d_inner, bias=True)

        # Each of the d_inner rows of A is [1, 2, ..., n_state]; stored as log(A).
        state_index = mx.arange(1.0, n_state + 1.0).reshape([1, n_state])
        A = mx.repeat(state_index, repeats=d_inner, axis=0)
        self.A_log = mx.log(A)
        self.D = mx.ones([d_inner])

        self.out_proj = nn.Linear(d_inner, d_model, bias=args.use_bias)

    def ssm_step(self, x, state=None):
        """Advance the state-space recurrence by one step.

        Computes ``new_state = (delta * x) outer B`` plus, when ``state`` is
        given, ``state * exp(delta * A)`` with ``A = -exp(A_log)``, and
        ``y = new_state @ C + D * x``.

        Args:
            x: Input for a single step (2-D as passed by ``__call__``).
            state: Previous state, or ``None`` for the first step.

        Returns:
            ``(y, new_state)``.
        """
        neg_A = -mx.exp(self.A_log)
        rank = self.time_step_rank
        projected = self.x_proj(x)
        dt, B, C = mx.split(
            projected,
            indices_or_sections=[rank, rank + self.ssm_state_size],
            axis=-1,
        )
        dt = nn.softplus(self.dt_proj(dt))

        new_state = mx.expand_dims(dt * x, -1) * mx.expand_dims(B, 1)
        if state is not None:
            decay = mx.exp(mx.expand_dims(dt, -1) * neg_A)
            new_state += state * decay

        readout = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
        return readout + self.D * x, new_state

    def __call__(self, x, cache):
        """Run the block over every step of ``x`` (rank 3, steps on axis 1).

        Args:
            x: Input whose shape unpacks into three dimensions.
            cache: Object supporting item get/set at indices 0 (conv context)
                and 1 (SSM state), or ``None`` to use a fresh local
                ``[None, None]``.

        Returns:
            The per-step outputs stacked along axis 1.
        """
        _, T, _ = x.shape
        if cache is None:
            cache = [None, None]

        per_step = []
        for t in range(T):
            projected = self.in_proj(x[:, t, :])
            u, gate = projected.split(indices_or_sections=2, axis=1)

            # The conv expects a step axis, so add one and drop it afterwards.
            u, cache[0] = self.conv1d(mx.expand_dims(u, 1), cache[0])
            u = nn.silu(u.squeeze(1))

            y, cache[1] = self.ssm_step(u, cache[1])
            gated = y * nn.silu(gate)
            per_step.append(self.out_proj(gated))

        return mx.stack(per_step, axis=1)
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """RMS-normalise the input, apply a ``MambaBlock``, and add the input back."""

    def __init__(self, args: ModelArgs):
        """Create the mixer and its pre-normalisation from ``args``."""
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        """Return ``mixer(norm(x), cache) + x``."""
        normed = self.norm(x)
        mixed = self.mixer(normed, cache)
        return mixed + x
|
|
|
|
|
|
class Mamba(nn.Module):
    """Backbone: token embedding, a stack of ``ResidualBlock``s, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        """Build the embedding table, ``num_hidden_layers`` blocks and final norm."""
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        """Embed ``x``, run it through every layer, and normalise the result.

        Args:
            x: Input passed to the embedding layer.
            cache: Iterable of per-layer caches paired with ``self.layers``
                via ``zip``, or ``None`` to give every layer ``None``.
        """
        hidden = self.embeddings(x)
        layer_caches = [None] * len(self.layers) if cache is None else cache
        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, block_cache)
        return self.norm_f(hidden)
|
|
|
|
|
|
class Model(nn.Module):
    """Mamba language model: backbone plus an output projection to the vocabulary.

    When ``args.tie_word_embeddings`` is true the embedding matrix is reused
    as the output projection; otherwise a separate bias-free ``lm_head`` is
    created.
    """

    def __init__(self, args: ModelArgs):
        """Build the backbone and, if embeddings are not tied, the ``lm_head``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return logits for ``inputs``.

        Args:
            inputs: Token ids passed to the backbone's embedding layer.
            cache: Optional per-layer caches (see ``make_cache``); forwarded
                to the backbone unchanged.
        """
        x = self.backbone(inputs, cache)

        if self.args.tie_word_embeddings:
            return self.backbone.embeddings.as_linear(x)
        return self.lm_head(x)

    def sanitize(self, weights):
        """Convert conv weights to the layout ``DepthWiseConv1d`` expects.

        ``DepthWiseConv1d`` stores its weight as ``(channels, kernel_size, 1)``.
        Any 3-D ``conv1d.weight`` whose last dimension is not 1 (presumably a
        ``(channels, 1, kernel_size)`` checkpoint layout -- confirm against
        the source checkpoints) has its last two axes swapped. Weights whose
        last dimension is already 1 are left untouched, so calling this on
        already-converted weights is a no-op rather than a second transpose.

        Args:
            weights: Mapping from parameter name to array; modified in place.

        Returns:
            The same ``weights`` mapping.
        """
        for name, value in weights.items():
            is_conv_weight = "conv1d.weight" in name and value.ndim == 3
            if is_conv_weight and value.shape[-1] != 1:
                # Only existing keys are reassigned, so mutating during
                # iteration does not change the dict's size.
                weights[name] = value.moveaxis(2, 1)
        return weights

    def make_cache(self):
        """Return one fresh ``MambaCache`` per backbone layer."""
        return [MambaCache() for _ in self.layers]

    @property
    def layers(self):
        """The backbone's list of layers."""
        return self.backbone.layers
|