diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
new file mode 100644
index 00000000..1d8215dd
--- /dev/null
+++ b/llms/mlx_lm/models/plamo2.py
@@ -0,0 +1,601 @@
+# Copyright © 2025 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import BaseModelArgs, create_attention_mask
+
+from .cache import KVCache, MambaCache
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str = "plamo2"
+    hidden_size: int = 4096
+    num_hidden_layers: int = 32
+    rms_norm_eps: float = 1e-6
+    tie_word_embeddings: bool = True
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 4
+    hidden_size_per_head: int = 128
+    attention_window_size: int = 2048
+    full_attention_idx: Optional[list[int]] = None
+    mamba_d_state: int = 64
+    mamba_d_conv: int = 4
+    mamba_num_heads: int = 64
+    mamba_step: int = 2
+    mamba_chunk_size: int = 256
+    mamba_enabled: bool = True
+    intermediate_size: int = 13312
+    vocab_size: int = 32000
+    max_position_embeddings: int = 10 * 1024 * 1024
+
+
+class RMSNorm(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        offset: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.weight = mx.zeros(hidden_size)
+        self.variance_epsilon = eps
+        self.offset = offset
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        return mx.fast.rms_norm(
+            hidden_states, self.weight + self.offset, self.variance_epsilon
+        )
+
+
+def get_initial_dt_bias(num_heads: int) -> mx.array:
+    dt_min = 0.001
+    dt_max = 0.1
+    dt = mx.exp(
+        mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
+        + math.log(dt_min)
+    )
+    dt = mx.clip(dt, a_min=1e-4, a_max=None)
+    inv_dt = dt + mx.log(-mx.expm1(-dt))
+    return inv_dt
+
+
+def get_initial_A(num_heads: int) -> mx.array:
+    A = mx.arange(1, num_heads + 1, dtype=mx.float32)
+    return mx.log(A)
+
+
+# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
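+# Single-step selective SSM update. For each token, the recurrence implemented below is
+#     state <- state * exp(dt * A) + dt * B * x
+#     out   <- einsum(state, C) + D * x, gated by silu(z) when z is given,
+# with B and C repeated across head groups (nheads // ngroups).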
+def selective_state_update_ref(
+    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
+) -> tuple[mx.array, mx.array]:
+    """
+    Argument:
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
+    Return:
+        out: (batch, dim) or (batch, nheads, dim)
+    """
+    has_heads = state.ndim > 3
+    if state.ndim == 3:
+        state = mx.expand_dims(state, 1)
+    if x.ndim == 2:
+        x = mx.expand_dims(x, 1)
+    if dt.ndim == 2:
+        dt = mx.expand_dims(dt, 1)
+    if A.ndim == 2:
+        A = mx.expand_dims(A, 0)
+    if B.ndim == 2:
+        B = mx.expand_dims(B, 1)
+    if C.ndim == 2:
+        C = mx.expand_dims(C, 1)
+    if D is not None and D.ndim == 1:
+        D = mx.expand_dims(D, 0)
+    if z is not None and z.ndim == 2:
+        z = mx.expand_dims(z, 1)
+    if dt_bias is not None and dt_bias.ndim == 1:
+        dt_bias = mx.expand_dims(dt_bias, 0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
+    assert dt.shape == x.shape
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
+    assert C.shape == B.shape
+    if D is not None:
+        assert D.shape == (nheads, dim)
+    if z is not None:
+        assert z.shape == x.shape
+    if dt_bias is not None:
+        assert dt_bias.shape == (nheads, dim)
+        dt = dt + dt_bias
+    dt = nn.softplus(dt) if dt_softplus else dt
+    dA = mx.exp(mx.expand_dims(dt, axis=-1) * A)  # (batch, nheads, dim, dstate)
+    B = mx.reshape(
+        mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
+        (batch, nheads, dstate),
+    )  # (batch, nheads, dstate)
+    C = mx.reshape(
+        mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
+        (batch, nheads, dstate),
+    )  # (batch, nheads, dstate)
+    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
+        B, axis=-2
+    )  # (batch, nheads, dim, dstate)
+    state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
+    out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
+    if D is not None:
+        out += (x * D).astype(out.dtype)
+    out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
+    if not has_heads:
+        out = out.squeeze(1)
+    return out, state
+
+
+def ssd_update_state(
+    ssm_state: mx.array,
+    x: mx.array,
+    dt: mx.array,
+    A: mx.array,
+    B: mx.array,
+    C: mx.array,
+    D: mx.array,
+    z: mx.array,
+    dt_bias: mx.array,
+    dt_softplus: bool,
+) -> tuple[mx.array, mx.array]:
+    assert ssm_state.dtype == mx.float32
+    dtype = x.dtype
+
+    hidden_size_per_head = x.shape[-1]
+    d_state = B.shape[-1]
+    A = mx.broadcast_to(
+        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
+    ).astype(mx.float32)
+    dt = mx.broadcast_to(
+        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
+    )
+    dt_bias = mx.broadcast_to(
+        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
+    )
+    D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
+    out, ssm_state = selective_state_update_ref(
+        ssm_state,
+        x.astype(dtype),
+        dt.astype(dtype),
+        A.astype(mx.float32),
+        B.astype(dtype),
+        C.astype(dtype),
+        D.astype(mx.float32),
+        z.astype(dtype),
+        dt_bias.astype(mx.float32),
+        dt_softplus=dt_softplus,
+    )
+    return out[:, None], ssm_state
+
+
+def ssd_chunk_scan_combined(
+    x: mx.array,
+    dt: mx.array,
+    A: mx.array,
+    B: mx.array,
+    C: mx.array,
+    D: mx.array,
+    z: mx.array,
+    dt_bias: mx.array,
+    dt_softplus: bool,
+    ssm_state: mx.array,
+) -> tuple[mx.array, mx.array]:
+    assert ssm_state.dtype == mx.float32
+    length = x.shape[1]
+    ys = []
+    for i in range(length):
+        y, ssm_state = ssd_update_state(
+            ssm_state,
+            x[:, i],
+            dt[:, i],
+            A,
+            B[:, i],
+            C[:, i],
+            D if D.ndim == 1 else D[:, i],
+            z=z[:, i],
+            dt_bias=dt_bias,
+            dt_softplus=dt_softplus,
+        )
+        ys.append(y)
+    return mx.concatenate(ys, axis=1), ssm_state
+
+
+def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
+    batch, seqlen, dim = x.shape
+    width = weight.shape[1]
+    state_len = conv_state.shape[-2]
+    x = mx.concatenate([conv_state, x], axis=-2)
+    conv_state = x[:, -state_len:]
+    out = mx.conv1d(
+        x,
+        weight,
+        padding=0,
+        groups=dim,
+    )[:, -seqlen:]
+    return nn.silu(out), conv_state
+
+
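+# Mamba mixer: in_proj splits the hidden state into a gate (z) and an input (x); x goes
+# through a depthwise causal conv (causal_conv1d_update), bcdt_proj then produces B, C,
+# and dt, and the selective scan is applied per head before out_proj maps back to
+# hidden_size.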
+class Mamba(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.d_state = config.mamba_d_state
+        self.d_conv = config.mamba_d_conv
+        self.chunk_size = config.mamba_chunk_size
+        self.num_heads = config.mamba_num_heads
+        self.hidden_size_per_head = config.hidden_size_per_head
+
+        self.intermediate_size = self.num_heads * self.hidden_size_per_head
+
+        self.in_proj = nn.Linear(
+            self.hidden_size, 2 * self.intermediate_size, bias=False
+        )
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=False,
+            kernel_size=self.d_conv,
+            groups=self.intermediate_size,
+            padding=0,
+        )
+        self.dt_dim = max(64, self.hidden_size // 16)
+        self.bcdt_proj = nn.Linear(
+            self.intermediate_size,
+            self.dt_dim + 2 * self.d_state,
+            bias=False,
+        )
+        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
+
+        self.dt_bias = get_initial_dt_bias(self.num_heads)
+        self.A_log = get_initial_A(self.num_heads)
+        self.D = mx.ones(self.num_heads, dtype=mx.float32)
+
+        self.dt_norm_weight = mx.ones(self.dt_dim)
+        self.B_norm_weight = mx.ones(self.d_state)
+        self.C_norm_weight = mx.ones(self.d_state)
+
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        bsize, length, _ = hidden_states.shape
+
+        if cache is not None and cache[0] is not None:
+            conv_state = cache[0]
+            ssm_state = cache[1]
+        else:
+            conv_state = mx.zeros(
+                (bsize, self.d_conv - 1, self.intermediate_size),
+                dtype=hidden_states.dtype,
+            )
+            ssm_state = mx.zeros(
+                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
+                dtype=mx.float32,
+            )
+
+        zx = self.in_proj(hidden_states)
+        zx = zx.reshape(bsize, length, self.num_heads, -1)
+        # z: (bsize, length, num_heads, hidden_size_per_head)
+        # x: (bsize, length, num_heads, hidden_size_per_head)
+        z, x = mx.split(
+            zx,
+            [
+                self.hidden_size_per_head,
+            ],
+            axis=-1,
+        )
+
+        x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
+        x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
+        BCdt = self.bcdt_proj(x)
+        x = x.reshape(bsize, length, self.num_heads, -1)
+        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
+
+        A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
+        dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
+        B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps)
+        C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps)
+
+        # (bsize, length, num_heads, 1)
+        dt = self.dt_proj(dt)[..., None]
+
+        out, ssm_state = ssd_chunk_scan_combined(
+            x,
+            dt.reshape(bsize, length, -1),
+            A,
+            B,
+            C,
+            D=self.D,
+            z=z,
+            dt_bias=self.dt_bias,
+            dt_softplus=True,
+            ssm_state=ssm_state,
+        )
+
+        if cache is not None:
+            cache[0] = conv_state
+            cache[1] = ssm_state
+        y = self.out_proj(out.reshape(bsize, length, -1))
+
+        return y
+
+
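+# Attention mixer: grouped-query attention (q_num_heads query heads sharing k_num_heads
+# key/value heads) with QK normalization (a parameter-free layer norm scaled by the
+# learned q_weight / k_weight) followed by RoPE.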
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        head_dim = config.hidden_size_per_head
+        self.max_position_embeddings = config.max_position_embeddings
+        self.scale = head_dim**-0.5
+
+        self.q_num_heads = config.num_attention_heads
+        self.qk_dim = self.v_dim = head_dim
+        self.k_num_heads = self.v_num_heads = config.num_key_value_heads
+        assert self.q_num_heads % self.k_num_heads == 0
+        self.n_group = self.q_num_heads // self.k_num_heads
+
+        self.q_proj_dim = self.q_num_heads * self.qk_dim
+        self.k_proj_dim = self.k_num_heads * self.qk_dim
+        self.v_proj_dim = self.k_num_heads * self.v_dim
+        self.qkv_proj = nn.Linear(
+            self.hidden_size,
+            self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
+            bias=False,
+        )
+        self.o_proj = nn.Linear(
+            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
+        )
+
+        self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
+        self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))
+
+        self.rope = nn.RoPE(self.qk_dim)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        B, T, _ = hidden_states.shape
+
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = mx.split(
+            qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
+        )
+        q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
+        k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
+        v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
+
+        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
+        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+
+        if cache is not None:
+            q = self.rope(q, offset=cache.offset)
+            k = self.rope(k, offset=cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+        else:
+            q = self.rope(q)
+            k = self.rope(k)
+
+        output = mx.fast.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            scale=self.scale,
+            mask=mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(
+            B, T, self.q_num_heads * self.v_dim
+        )
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size * 2, bias=False
+        )
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        h = self.gate_up_proj(x)
+        hs = mx.split(h, 2, axis=-1)
+        return self.down_proj(nn.silu(hs[0]) * hs[1])
+
+
+class PlamoDecoderLayer(nn.Module):
+    def __init__(self, config: ModelArgs, is_mamba: bool) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.is_mamba = is_mamba
+        self.mixer: nn.Module
+        if is_mamba:
+            self.mixer = Mamba(config)
+        else:
+            self.mixer = Attention(config)
+        self.mlp = MLP(config)
+        self.pre_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mixer_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
+        )
+        self.pre_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+        )
+        self.post_mlp_norm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
+        )
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        residual = hidden_states
+        hidden_states = self.pre_mixer_norm(hidden_states)
+
+        hidden_states_sa = self.mixer(
+            hidden_states=hidden_states,
+            mask=mask,
+            cache=cache,
+        )
+
+        hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
+        hidden_states = residual + hidden_states_sa
+
+        residual = hidden_states
+        hidden_states = self.pre_mlp_norm(hidden_states)
+
+        # Fully Connected
+        hidden_states_mlp = self.mlp(hidden_states)
+
+        # Residual
+        hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
+        return residual + hidden_states_mlp
+
+
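+# Layer layout: with the default mamba_step of 2, even-indexed layers use the Mamba mixer
+# and odd-indexed layers use attention; when num_hidden_layers <= mamba_step // 2, every
+# layer except the last is Mamba.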
+def is_mamba(config: ModelArgs, i: int) -> bool:
+    if not config.mamba_enabled:
+        return False
+    assert config.mamba_step > 1
+    assert i < config.num_hidden_layers
+
+    if config.num_hidden_layers <= (config.mamba_step // 2):
+        # use attention in last layer
+        return i != config.num_hidden_layers - 1
+    return (i % config.mamba_step) != (config.mamba_step // 2)
+
+
+class PlamoDecoder(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+
+        self.layers = [
+            PlamoDecoderLayer(config, is_mamba=is_mamba(config, i))
+            for i in range(config.num_hidden_layers)
+        ]
+
+    def __call__(self, x: mx.array, mask: mx.array, cache):
+        for i, decoder_layer in enumerate(self.layers):
+            x = decoder_layer(
+                x,
+                mask=mask,
+                cache=cache[i],
+            )
+        return x
+
+
+class PlamoModel(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = PlamoDecoder(config)  # type: ignore
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        batch_size, seq_length = inputs.shape
+
+        h = self.embed_tokens(inputs)
+
+        if mask is None:
+            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
+
+        if cache is None:
+            cache = [None] * len(self.layers.layers)
+
+        # decoder layers
+        out = self.layers(
+            h,
+            mask,
+            cache,
+        )
+
+        return self.norm(out)
+
+
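+# Top-level wrapper: ties the output projection to the input embeddings when
+# tie_word_embeddings is set, reorders conv1d weights from (channels, 1, kernel) to
+# MLX's (channels, kernel, 1) layout in sanitize(), and allocates a MambaCache or
+# KVCache per layer to match that layer's mixer.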
+class Model(nn.Module):
+
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = PlamoModel(config)
+
+        self.vocab_size = config.vocab_size
+
+        if not config.tie_word_embeddings:
+            self.lm_head: nn.Module = nn.Linear(
+                config.hidden_size, config.vocab_size, bias=False
+            )
+
+    def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
+
+    def make_cache(self):
+        # TODO: use RotatingKVCache if not full_attn
+        # full_attn = self.layer_idx in self.config.full_attention_idx
+        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
+
+    def __call__(
+        self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
+    ) -> mx.array:
+        outputs = self.model(
+            inputs=inputs,
+            mask=None,
+            cache=cache,
+        )
+        if self.config.tie_word_embeddings:
+            logits = self.model.embed_tokens.as_linear(outputs)
+        else:
+            logits = self.lm_head(outputs)
+
+        return logits
+
+    @property
+    def layers(self):
+        return self.model.layers.layers
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 1fae76fa..2d760743 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -192,6 +192,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
                     "tokenizer.model",
                     "*.tiktoken",
                     "*.txt",
+                    "*.jsonl",
                 ],
             )
         )
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index d8cf6820..0c0fc601 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)
 
-            if model_type != "mamba":
+            if model_type not in ("mamba", "plamo2"):
                 mask = create_causal_mask(inputs.shape[1], 0).astype(t)
                 outputs = model(inputs, mask=mask)
                 self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_plamo2(self):
+        from mlx_lm.models import plamo2
+
+        args = plamo2.ModelArgs(
+            model_type="plamo2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_stablelm(self):
         from mlx_lm.models import stablelm