diff --git a/clip/convert.py b/clip/convert.py index 29bac22e..976d7494 100644 --- a/clip/convert.py +++ b/clip/convert.py @@ -121,7 +121,7 @@ if __name__ == "__main__": mlx_path.mkdir(parents=True, exist_ok=True) print("[INFO] Loading") - torch_weights = torch.load(torch_path / "pytorch_model.bin") + torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True) print("[INFO] Converting") mlx_weights = { k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items() diff --git a/llms/README.md b/llms/README.md index 4f7451c1..e2d1db59 100644 --- a/llms/README.md +++ b/llms/README.md @@ -123,6 +123,18 @@ for response in stream_generate(model, tokenizer, prompt, max_tokens=512): print() ``` +#### Sampling + +The `generate` and `stream_generate` functions accept `sampler` and +`logits_processors` keyword arguments. A sampler is any callable which accepts +a possibly batched logits array and returns an array of sampled tokens. The +`logits_processors` must be a list of callables which take the token history +and current logits as input and return the processed logits. The logits +processors are applied in order. + +Some standard sampling functions and logits processors are provided in +`mlx_lm.sample_utils`. + ### Command Line You can also use `mlx-lm` from the command line with: diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index b2f98e6f..89e6cd00 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.21.0" +__version__ = "0.21.5" diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index abc5dfa9..def3b6dd 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -181,8 +181,14 @@ def train_model( training_callback: TrainingCallback = None, ): model.freeze() + if args.num_layers > len(model.layers): + raise ValueError( + f"Requested to train {args.num_layers} layers " + f"but the model only has {len(model.layers)} layers." + ) + if args.fine_tune_type == "full": - for l in model.layers[-min(args.num_layers, 0) :]: + for l in model.layers[-max(args.num_layers, 0) :]: l.unfreeze() elif args.fine_tune_type in ["lora", "dora"]: # Convert linear layers to lora/dora layers and unfreeze in the process diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py new file mode 100644 index 00000000..1d8215dd --- /dev/null +++ b/llms/mlx_lm/models/plamo2.py @@ -0,0 +1,601 @@ +# Copyright © 2025 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import Any, Optional, Union + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import BaseModelArgs, create_attention_mask + +from .cache import KVCache, MambaCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "plamo2" + hidden_size: int = 4096 + num_hidden_layers: int = 32 + rms_norm_eps: float = 1e-6 + tie_word_embeddings: bool = True + num_attention_heads: int = 32 + num_key_value_heads: int = 4 + hidden_size_per_head: int = 128 + max_position_embeddings: int = 2048 + attention_window_size: int = 2048 + full_attention_idx: Optional[list[int]] = None + mamba_d_state: int = 64 + mamba_d_conv: int = 4 + mamba_num_heads: int = 64 + mamba_step: int = 2 + mamba_chunk_size: int = 256 + mamba_enabled: bool = True + intermediate_size: int = 13312 + vocab_size: int = 32000 + max_position_embeddings: int = 10 * 1024 * 1024 + + +class RMSNorm(nn.Module): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + offset: float = 1.0, + ) -> None: + super().__init__() + self.weight = mx.zeros(hidden_size) + self.variance_epsilon = eps + self.offset = offset + + def __call__(self, hidden_states: mx.array) -> mx.array: + return mx.fast.rms_norm( + hidden_states, self.weight + self.offset, self.variance_epsilon + ) + + +def get_initial_dt_bias(num_heads: int) -> mx.array: + dt_min = 0.001 + dt_max = 0.1 + dt = mx.exp( + mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = mx.clip(dt, a_min=1e-4, a_max=None) + inv_dt = dt + mx.log(-mx.expm1(-dt)) + return inv_dt + + +def get_initial_A(num_heads: int) -> mx.array: + A = mx.arange(1, num_heads + 1, dtype=mx.float32) + return mx.log(A) + + +# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219 +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +) -> tuple[mx.array, mx.array]: + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.ndim > 3 + if state.ndim == 3: + state = mx.expand_dims(state, 1) + if x.ndim == 2: + x = mx.expand_dims(x, 1) + if dt.ndim == 2: + dt = mx.expand_dims(dt, 1) + if A.ndim == 2: + A = mx.expand_dims(A, 0) + if B.ndim == 2: + B = mx.expand_dims(B, 1) + if C.ndim == 2: + C = mx.expand_dims(C, 1) + if D is not None and D.ndim == 1: + D = mx.expand_dims(D, 0) + if z is not None and z.ndim == 2: + z = mx.expand_dims(z, 1) + if dt_bias is not None and dt_bias.ndim == 1: + dt_bias = mx.expand_dims(dt_bias, 0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt 
+ dt_bias + dt = nn.softplus(dt) if dt_softplus else dt + dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate) + B = mx.reshape( + mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2), + (batch, nheads, dstate), + ) # (batch, nheads, dstate) + C = mx.reshape( + mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2), + (batch, nheads, dstate), + ) # (batch, nheads, dstate) + dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims( + B, axis=-2 + ) # (batch, nheads, dim, dstate) + state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate) + out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C) + if D is not None: + out += (x * D).astype(out.dtype) + out = (out if z is None else out * nn.silu(z)).astype(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out, state + + +def ssd_update_state( + ssm_state: mx.array, + x: mx.array, + dt: mx.array, + A: mx.array, + B: mx.array, + C: mx.array, + D: mx.array, + z: mx.array, + dt_bias: mx.array, + dt_softplus: bool, +) -> tuple[mx.array, mx.array]: + assert ssm_state.dtype == mx.float32 + dtype = x.dtype + + hidden_size_per_head = x.shape[-1] + d_state = B.shape[-1] + A = mx.broadcast_to( + A[:, None, None], (A.shape[0], hidden_size_per_head, d_state) + ).astype(mx.float32) + dt = mx.broadcast_to( + dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head) + ) + dt_bias = mx.broadcast_to( + dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head) + ) + D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head)) + out, ssm_state = selective_state_update_ref( + ssm_state, + x.astype(dtype), + dt.astype(dtype), + A.astype(mx.float32), + B.astype(dtype), + C.astype(dtype), + D.astype(mx.float32), + z.astype(dtype), + dt_bias.astype(mx.float32), + dt_softplus=dt_softplus, + ) + return out[:, None], ssm_state + + +def ssd_chunk_scan_combined( + x: mx.array, + dt: mx.array, + A: mx.array, + B: mx.array, + C: mx.array, + D: mx.array, + z: mx.array, + dt_bias: mx.array, + dt_softplus: bool, + ssm_state: mx.array, +) -> tuple[mx.array, mx.array]: + assert ssm_state.dtype == mx.float32 + length = x.shape[1] + ys = [] + for i in range(length): + y, ssm_state = ssd_update_state( + ssm_state, + x[:, i], + dt[:, i], + A, + B[:, i], + C[:, i], + D if D.ndim == 1 else D[:, i], + z=z[:, i], + dt_bias=dt_bias, + dt_softplus=dt_softplus, + ) + ys.append(y) + return mx.concatenate(ys, axis=1), ssm_state + + +def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: + batch, seqlen, dim = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-2] + x = mx.concatenate([conv_state, x], axis=-2) + conv_state = x[:, -state_len:] + out = mx.conv1d( + x, + weight, + padding=0, + groups=dim, + )[:, -seqlen:] + return nn.silu(out), conv_state + + +class Mamba(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.d_state = config.mamba_d_state + self.d_conv = config.mamba_d_conv + self.chunk_size = config.mamba_chunk_size + self.num_heads = config.mamba_num_heads + self.hidden_size_per_head = config.hidden_size_per_head + + self.intermediate_size = self.num_heads * self.hidden_size_per_head + + self.in_proj = nn.Linear( + self.hidden_size, 2 * self.intermediate_size, bias=False + ) + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=False, + kernel_size=self.d_conv, + groups=self.intermediate_size, + padding=0, + ) + 
self.dt_dim = max(64, self.hidden_size // 16) + self.bcdt_proj = nn.Linear( + self.intermediate_size, + self.dt_dim + 2 * self.d_state, + bias=False, + ) + self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False) + + self.dt_bias = get_initial_dt_bias(self.num_heads) + self.A_log = get_initial_A(self.num_heads) + self.D = mx.ones(self.num_heads, dtype=mx.float32) + + self.dt_norm_weight = mx.ones(self.dt_dim) + self.B_norm_weight = mx.ones(self.d_state) + self.C_norm_weight = mx.ones(self.d_state) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__( + self, + hidden_states: mx.array, + mask: Optional[mx.array] = None, + cache=None, + ): + bsize, length, _ = hidden_states.shape + + if cache is not None and cache[0] is not None: + conv_state = cache[0] + ssm_state = cache[1] + else: + conv_state = mx.zeros( + (bsize, self.d_conv - 1, self.intermediate_size), + dtype=hidden_states.dtype, + ) + ssm_state = mx.zeros( + (bsize, self.num_heads, self.hidden_size_per_head, self.d_state), + dtype=mx.float32, + ) + + zx = self.in_proj(hidden_states) + zx = zx.reshape(bsize, length, self.num_heads, -1) + # z: (bsize, length, num_heads, hidden_size_per_head) + # x: (bsize, length, num_heads, hidden_size_per_head) + z, x = mx.split( + zx, + [ + self.hidden_size_per_head, + ], + axis=-1, + ) + + x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head) + x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight) + BCdt = self.bcdt_proj(x) + x = x.reshape(bsize, length, self.num_heads, -1) + B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) + + A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) + dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps) + B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps) + C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps) + + # (bsize, length, num_heads, 1) + dt = self.dt_proj(dt)[..., None] + + out, ssm_state = ssd_chunk_scan_combined( + x, + dt.reshape(bsize, length, -1), + A, + B, + C, + D=self.D, + z=z, + dt_bias=self.dt_bias, + dt_softplus=True, + ssm_state=ssm_state, + ) + + if cache is not None: + cache[0] = conv_state + cache[1] = ssm_state + y = self.out_proj(out.reshape(bsize, length, -1)) + + return y + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + head_dim = config.hidden_size_per_head + self.max_position_embeddings = config.max_position_embeddings + self.scale = head_dim**-0.5 + + self.q_num_heads = config.num_attention_heads + self.qk_dim = self.v_dim = head_dim + self.k_num_heads = self.v_num_heads = config.num_key_value_heads + assert self.q_num_heads % self.k_num_heads == 0 + self.n_group = self.q_num_heads // self.k_num_heads + + self.q_proj_dim = self.q_num_heads * self.qk_dim + self.k_proj_dim = self.k_num_heads * self.qk_dim + self.v_proj_dim = self.k_num_heads * self.v_dim + self.qkv_proj = nn.Linear( + self.hidden_size, + self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, + bias=False, + ) + self.o_proj = nn.Linear( + self.q_num_heads * self.v_dim, self.hidden_size, bias=False + ) + + self.q_weight = mx.ones((self.q_num_heads, self.qk_dim)) + self.k_weight = mx.ones((self.k_num_heads, self.qk_dim)) + + self.rope = nn.RoPE(self.qk_dim) + + def __call__( + self, + hidden_states: mx.array, + mask: Optional[mx.array] = None, + cache=None, + ): + B, T, _ = 
hidden_states.shape + + qkv = self.qkv_proj(hidden_states) + q, k, v = mx.split( + qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1 + ) + q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) + v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) + + q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None] + k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None] + + if cache is not None: + q = self.rope(q, offset=cache.offset) + k = self.rope(k, offset=cache.offset) + k, v = cache.update_and_fetch(k, v) + else: + q = self.rope(q) + k = self.rope(k) + + output = mx.fast.scaled_dot_product_attention( + q, + k, + v, + scale=self.scale, + mask=mask, + ) + output = output.transpose(0, 2, 1, 3).reshape( + B, T, self.q_num_heads * self.v_dim + ) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear( + self.hidden_size, self.intermediate_size * 2, bias=False + ) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__(self, x: mx.array) -> mx.array: + h = self.gate_up_proj(x) + hs = mx.split(h, 2, axis=-1) + return self.down_proj(nn.silu(hs[0]) * hs[1]) + + +class PlamoDecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, is_mamba: bool) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.is_mamba = is_mamba + self.mixer: nn.Module + if is_mamba: + self.mixer = Mamba(config) + else: + self.mixer = Attention(config) + self.mlp = MLP(config) + self.pre_mixer_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 + ) + self.post_mixer_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5 + ) + self.pre_mlp_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 + ) + self.post_mlp_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5) + ) + + def __call__( + self, + hidden_states: mx.array, + mask: Optional[mx.array] = None, + cache=None, + ): + residual = hidden_states + hidden_states = self.pre_mixer_norm(hidden_states) + + hidden_states_sa = self.mixer( + hidden_states=hidden_states, + mask=mask, + cache=cache, + ) + + hidden_states_sa = self.post_mixer_norm(hidden_states_sa) + hidden_states = residual + hidden_states_sa + + residual = hidden_states + hidden_states = self.pre_mlp_norm(hidden_states) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + + # Residual + hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp) + return residual + hidden_states_mlp + + +def is_mamba(config: ModelArgs, i: int) -> bool: + if not config.mamba_enabled: + return False + assert config.mamba_step > 1 + assert i < config.num_hidden_layers + + if config.num_hidden_layers <= (config.mamba_step // 2): + # use attention in last layer + return i != config.num_hidden_layers - 1 + return (i % config.mamba_step) != (config.mamba_step // 2) + + +class PlamoDecoder(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + + self.layers = [ + PlamoDecoderLayer(config, is_mamba=is_mamba(config, i)) + for i in range(config.num_hidden_layers) + ] + + def __call__(self, x: mx.array, mask: mx.array, cache): + for i, 
decoder_layer in enumerate(self.layers):
+            x = decoder_layer(
+                x,
+                mask=mask,
+                cache=cache[i],
+            )
+        return x
+
+
+class PlamoModel(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = PlamoDecoder(config)  # type: ignore
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        batch_size, seq_length = inputs.shape
+
+        h = self.embed_tokens(inputs)
+
+        if mask is None:
+            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
+
+        if cache is None:
+            cache = [None] * len(self.layers.layers)
+
+        # decoder layers
+        out = self.layers(
+            h,
+            mask,
+            cache,
+        )
+
+        return self.norm(out)
+
+
+class Model(nn.Module):
+
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = PlamoModel(config)
+
+        self.vocab_size = config.vocab_size
+
+        if not config.tie_word_embeddings:
+            self.lm_head: nn.Module = nn.Linear(
+                config.hidden_size, config.vocab_size, bias=False
+            )
+
+    def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
+
+    def make_cache(self):
+        # TODO use RotatingKVCache if not full_attn
+        # full_attn = self.layer_idx in self.config.full_attention_idx
+        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
+
+    def __call__(
+        self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
+    ) -> mx.array:
+        outputs = self.model(
+            inputs=inputs,
+            mask=None,
+            cache=cache,
+        )
+        if self.config.tie_word_embeddings:
+            logits = self.model.embed_tokens.as_linear(outputs)
+        else:
+            logits = self.lm_head(outputs)
+
+        return logits
+
+    @property
+    def layers(self):
+        return self.model.layers.layers
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 7c08b001..6031d763 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
- ) def to_lora(layer): if isinstance(layer, (nn.Linear, nn.QuantizedLinear)): @@ -161,7 +156,7 @@ def linear_to_lora_layers( else: raise ValueError(f"Lora does not support {model.model_type}") - for l in model.layers[-min(num_layers, 0) :]: + for l in model.layers[-max(num_layers, 0) :]: lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] if lora_layers: l.update_modules(tree_unflatten(lora_layers)) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 64813123..2d760743 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -192,6 +192,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path "tokenizer.model", "*.tiktoken", "*.txt", + "*.jsonl", ], ) ) @@ -382,8 +383,8 @@ def speculative_generate_step( and a bool indicating if the token was generated by the draft model """ - y = prompt - tokens = None + y = prompt.astype(mx.uint32) + prev_tokens = None # Create the KV cache for generation if prompt_cache is None: @@ -404,17 +405,38 @@ def speculative_generate_step( kv_bits=kv_bits, ) + def _process_and_sample(tokens, logits): + if logits_processors: + for processor in logits_processors: + logits = processor(tokens, logits) + + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + y = sampler(logprobs) + return y, logprobs + def _step(model, cache, y, n_predict=1): with mx.stream(generation_stream): logits = model(y[None], cache=cache) logits = logits[:, -n_predict:, :] quantize_cache_fn(cache) - - logprobs = logits - mx.logsumexp(logits, keepdims=True) - logprobs = logprobs.squeeze(0) - y = sampler(logprobs) - return y, logprobs + if logits_processors: + nonlocal prev_tokens + out_y, out_logprobs = [], [] + if n_predict > 1: + y = y[: -(n_predict - 1)] + for i in range(n_predict): + prev_tokens = ( + mx.concat([prev_tokens, y]) if prev_tokens is not None else y + ) + y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :]) + out_y.append(y) + out_logprobs.append(logprobs) + return mx.concatenate(out_y, axis=0), mx.concatenate( + out_logprobs, axis=0 + ) + else: + return _process_and_sample(None, logits.squeeze(0)) def _prefill(model, cache, y): while y.size > prefill_step_size: @@ -451,8 +473,9 @@ def speculative_generate_step( while True: num_draft = min(max_tokens - ntoks, num_draft_tokens) draft_tokens = _draft_generate(draft_y, num_draft) + if prev_tokens is not None: + prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1] y = mx.concatenate([y, draft_tokens]) - tokens, logprobs = _step(model, model_cache, y, num_draft + 1) mx.eval(tokens, draft_tokens) draft_tokens = draft_tokens.tolist() @@ -485,6 +508,8 @@ def speculative_generate_step( [mx.array(draft_tokens[-1:], mx.uint32), draft_y] ) + if prev_tokens is not None: + prev_tokens = prev_tokens[: -max(num_draft - n, 1)] _rewind_cache(num_draft, n) finally: _rewind_cache(num_draft, n) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index d8cf6820..0c0fc601 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -183,7 +183,7 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - if model_type != "mamba": + if model_type not in ("mamba", "plamo2"): mask = create_causal_mask(inputs.shape[1], 0).astype(t) outputs = model(inputs, mask=mask) self.assertEqual(outputs.shape, (1, 2, vocab_size)) @@ -372,6 +372,23 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def 
test_plamo2(self): + from mlx_lm.models import plamo2 + + args = plamo2.ModelArgs( + model_type="plamo2", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=8, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + model = plamo2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_stablelm(self): from mlx_lm.models import stablelm diff --git a/transformer_lm/main.py b/transformer_lm/main.py index dc725cbe..7ff5b73f 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -8,7 +8,6 @@ import datasets import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -import numpy as np from mlx.utils import tree_flatten @@ -40,26 +39,21 @@ class TransformerLM(nn.Module): def to_samples(context_size, dataset): - tokens = dataset.size window_size = context_size + 1 # include target - samples = tokens - window_size + 1 - X = np.lib.stride_tricks.as_strided( - dataset, - shape=(samples, window_size), - strides=(dataset.itemsize, dataset.itemsize), - ) - return X[:, :-1], X[:, 1:] + samples = dataset.size // window_size + dataset = dataset[: samples * window_size] + return mx.array(dataset.reshape(samples, -1)) def iterate_batches(batch_size, context_size, dataset): - inputs, targets = to_samples(context_size, dataset) + inputs = to_samples(context_size, dataset) s = 0 while True: if s == 0: # Reset permutation: - perm = np.random.permutation(inputs.shape[0]) + perm = mx.random.permutation(inputs.shape[0]) ids = perm[s : s + batch_size] - yield inputs[ids], targets[ids] + yield inputs[ids] s += batch_size if s >= inputs.shape[0]: s = 0 @@ -84,45 +78,42 @@ def main(args): ) print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters") - def loss_fn(model, x, y, reduce=True): + def loss_fn(model, inputs, reduction="mean"): + x, y = inputs[..., :-1], inputs[..., 1:] logits = model(x) - losses = nn.losses.cross_entropy(logits, y) - return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2)) + return nn.losses.cross_entropy(logits, y, reduction=reduction) optimizer = optim.AdamW( learning_rate=args.learning_rate, weight_decay=args.weight_decay ) def eval_fn(dataset): - inputs, targets = map(mx.array, to_samples(context_size, dataset)) + inputs = to_samples(context_size, dataset) loss = 0 - for s in range(0, targets.shape[0], batch_size): - bx, by = inputs[s : s + batch_size], targets[s : s + batch_size] - bx, by = map(mx.array, (bx, by)) - losses = loss_fn(model, bx, by, reduce=False) - loss += mx.sum(losses).item() - return loss / len(targets) + for s in range(0, inputs.shape[0], batch_size): + losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum") + loss += losses.item() + return loss / (inputs.size - inputs.shape[0]) state = [model.state, optimizer.state] @partial(mx.compile, inputs=state, outputs=state) - def step(inputs, targets): + def step(inputs): loss_and_grad_fn = nn.value_and_grad(model, loss_fn) - loss, grads = loss_and_grad_fn(model, inputs, targets) + loss, grads = loss_and_grad_fn(model, inputs) optimizer.update(model, grads) return loss train_iterator = iterate_batches(batch_size, context_size, train) losses = [] tic = time.perf_counter() - for it, (inputs, targets) in zip(range(args.num_iters), train_iterator): - inputs, targets = map(mx.array, (inputs, targets)) + for it, inputs in zip(range(args.num_iters), train_iterator): optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate - loss = step(inputs, targets) + 
loss = step(inputs) mx.eval(state) losses.append(loss.item()) if (it + 1) % steps_per_report == 0: - train_loss = np.mean(losses) + train_loss = sum(losses) / len(losses) toc = time.perf_counter() print( f"Iter {it + 1}: Train loss {train_loss:.3f}, "