Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-27 11:21:32 +08:00)
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Commit: 42c3cd2084
@@ -121,7 +121,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)

     print("[INFO] Loading")
-    torch_weights = torch.load(torch_path / "pytorch_model.bin")
+    torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
     print("[INFO] Converting")
     mlx_weights = {
         k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
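Editor's note (not part of the diff): `weights_only=True` restricts `torch.load` to tensor data instead of unpickling arbitrary objects, which is the safer choice for downloaded checkpoints. A minimal sketch:

```python
# Sketch: load a checkpoint without executing arbitrary pickled code.
import torch

weights = torch.load("pytorch_model.bin", weights_only=True)  # tensors only
```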
@@ -123,6 +123,18 @@ for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
     print()
 ```

+#### Sampling
+
+The `generate` and `stream_generate` functions accept `sampler` and
+`logits_processors` keyword arguments. A sampler is any callable which accepts
+a possibly batched logits array and returns an array of sampled tokens. The
+`logits_processors` must be a list of callables which take the token history
+and current logits as input and return the processed logits. The logits
+processors are applied in order.
+
+Some standard sampling functions and logits processors are provided in
+`mlx_lm.sample_utils`.
+
 ### Command Line

 You can also use `mlx-lm` from the command line with:
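For reference, a minimal sketch (not part of this commit) of how the `sampler` and `logits_processors` arguments described in the added README section fit together. The `greedy` sampler and `temperature_processor` below are illustrative names, and the model repo is only an example; `mlx_lm.sample_utils` also provides ready-made helpers such as `make_sampler`.

```python
# Sketch of the sampler / logits_processors contract described above.
# `greedy` and `temperature_processor` are illustrative, not mlx-lm APIs.
import mlx.core as mx
from mlx_lm import load, generate


def greedy(logprobs: mx.array) -> mx.array:
    # A sampler: maps (possibly batched) log-probabilities to sampled tokens.
    return mx.argmax(logprobs, axis=-1)


def temperature_processor(tokens: mx.array, logits: mx.array) -> mx.array:
    # A logits processor: sees the token history and the current logits.
    return logits / 0.8


model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    sampler=greedy,
    logits_processors=[temperature_processor],
    max_tokens=64,
)
print(text)
```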
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.21.0"
+__version__ = "0.21.5"
@@ -181,8 +181,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
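Editor's aside (not part of the commit): the `min` to `max` change matters because a negative slice start of `-0` is just `0`. A small sketch of the slicing behavior, assuming `num_layers = -1` is used to mean "all layers":

```python
# Hypothetical illustration of the slice fix above (not repository code).
layers = list(range(10))          # stand-in for model.layers
num_layers = 4

# Old expression: -min(4, 0) == 0, so layers[0:] selects *all* layers.
assert layers[-min(num_layers, 0):] == layers

# Fixed expression: -max(4, 0) == -4, so only the last 4 layers are selected.
assert layers[-max(num_layers, 0):] == [6, 7, 8, 9]

# If -1 means "all layers", max(-1, 0) == 0 and layers[0:] still selects
# everything, which is the intended behavior.
assert layers[-max(-1, 0):] == layers
```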
llms/mlx_lm/models/plamo2.py (new file, 601 lines)
@@ -0,0 +1,601 @@
# Copyright © 2025 Apple Inc.

import math
from dataclasses import dataclass
from typing import Any, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, create_attention_mask

from .cache import KVCache, MambaCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "plamo2"
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = True
    num_attention_heads: int = 32
    num_key_value_heads: int = 4
    hidden_size_per_head: int = 128
    max_position_embeddings: int = 2048
    attention_window_size: int = 2048
    full_attention_idx: Optional[list[int]] = None
    mamba_d_state: int = 64
    mamba_d_conv: int = 4
    mamba_num_heads: int = 64
    mamba_step: int = 2
    mamba_chunk_size: int = 256
    mamba_enabled: bool = True
    intermediate_size: int = 13312
    vocab_size: int = 32000
    max_position_embeddings: int = 10 * 1024 * 1024

class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        offset: float = 1.0,
    ) -> None:
        super().__init__()
        self.weight = mx.zeros(hidden_size)
        self.variance_epsilon = eps
        self.offset = offset

    def __call__(self, hidden_states: mx.array) -> mx.array:
        return mx.fast.rms_norm(
            hidden_states, self.weight + self.offset, self.variance_epsilon
        )


def get_initial_dt_bias(num_heads: int) -> mx.array:
    dt_min = 0.001
    dt_max = 0.1
    dt = mx.exp(
        mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
        + math.log(dt_min)
    )
    dt = mx.clip(dt, a_min=1e-4, a_max=None)
    inv_dt = dt + mx.log(-mx.expm1(-dt))
    return inv_dt


def get_initial_A(num_heads: int) -> mx.array:
    A = mx.arange(1, num_heads + 1, dtype=mx.float32)
    return mx.log(A)

# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
) -> tuple[mx.array, mx.array]:
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.ndim > 3
    if state.ndim == 3:
        state = mx.expand_dims(state, 1)
    if x.ndim == 2:
        x = mx.expand_dims(x, 1)
    if dt.ndim == 2:
        dt = mx.expand_dims(dt, 1)
    if A.ndim == 2:
        A = mx.expand_dims(A, 0)
    if B.ndim == 2:
        B = mx.expand_dims(B, 1)
    if C.ndim == 2:
        C = mx.expand_dims(C, 1)
    if D is not None and D.ndim == 1:
        D = mx.expand_dims(D, 0)
    if z is not None and z.ndim == 2:
        z = mx.expand_dims(z, 1)
    if dt_bias is not None and dt_bias.ndim == 1:
        dt_bias = mx.expand_dims(dt_bias, 0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = nn.softplus(dt) if dt_softplus else dt
    dA = mx.exp(mx.expand_dims(dt, axis=-1) * A)  # (batch, nheads, dim, dstate)
    B = mx.reshape(
        mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
        (batch, nheads, dstate),
    )  # (batch, nheads, dstate)
    C = mx.reshape(
        mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
        (batch, nheads, dstate),
    )  # (batch, nheads, dstate)
    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
        B, axis=-2
    )  # (batch, nheads, dim, dstate)
    state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
    out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
    if D is not None:
        out += (x * D).astype(out.dtype)
    out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out, state

def ssd_update_state(
    ssm_state: mx.array,
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    dtype = x.dtype

    hidden_size_per_head = x.shape[-1]
    d_state = B.shape[-1]
    A = mx.broadcast_to(
        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
    ).astype(mx.float32)
    dt = mx.broadcast_to(
        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
    )
    dt_bias = mx.broadcast_to(
        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
    )
    D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
    out, ssm_state = selective_state_update_ref(
        ssm_state,
        x.astype(dtype),
        dt.astype(dtype),
        A.astype(mx.float32),
        B.astype(dtype),
        C.astype(dtype),
        D.astype(mx.float32),
        z.astype(dtype),
        dt_bias.astype(mx.float32),
        dt_softplus=dt_softplus,
    )
    return out[:, None], ssm_state

def ssd_chunk_scan_combined(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
    ssm_state: mx.array,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    length = x.shape[1]
    ys = []
    for i in range(length):
        y, ssm_state = ssd_update_state(
            ssm_state,
            x[:, i],
            dt[:, i],
            A,
            B[:, i],
            C[:, i],
            D if D.ndim == 1 else D[:, i],
            z=z[:, i],
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
        )
        ys.append(y)
    return mx.concatenate(ys, axis=1), ssm_state

def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
    batch, seqlen, dim = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-2]
    x = mx.concatenate([conv_state, x], axis=-2)
    conv_state = x[:, -state_len:]
    out = mx.conv1d(
        x,
        weight,
        padding=0,
        groups=dim,
    )[:, -seqlen:]
    return nn.silu(out), conv_state

class Mamba(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.d_state = config.mamba_d_state
        self.d_conv = config.mamba_d_conv
        self.chunk_size = config.mamba_chunk_size
        self.num_heads = config.mamba_num_heads
        self.hidden_size_per_head = config.hidden_size_per_head

        self.intermediate_size = self.num_heads * self.hidden_size_per_head

        self.in_proj = nn.Linear(
            self.hidden_size, 2 * self.intermediate_size, bias=False
        )
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=False,
            kernel_size=self.d_conv,
            groups=self.intermediate_size,
            padding=0,
        )
        self.dt_dim = max(64, self.hidden_size // 16)
        self.bcdt_proj = nn.Linear(
            self.intermediate_size,
            self.dt_dim + 2 * self.d_state,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)

        self.dt_bias = get_initial_dt_bias(self.num_heads)
        self.A_log = get_initial_A(self.num_heads)
        self.D = mx.ones(self.num_heads, dtype=mx.float32)

        self.dt_norm_weight = mx.ones(self.dt_dim)
        self.B_norm_weight = mx.ones(self.d_state)
        self.C_norm_weight = mx.ones(self.d_state)

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        bsize, length, _ = hidden_states.shape

        if cache is not None and cache[0] is not None:
            conv_state = cache[0]
            ssm_state = cache[1]
        else:
            conv_state = mx.zeros(
                (bsize, self.d_conv - 1, self.intermediate_size),
                dtype=hidden_states.dtype,
            )
            ssm_state = mx.zeros(
                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                dtype=mx.float32,
            )

        zx = self.in_proj(hidden_states)
        zx = zx.reshape(bsize, length, self.num_heads, -1)
        # z: (bsize, length, num_heads, hidden_size_per_head)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        z, x = mx.split(
            zx,
            [
                self.hidden_size_per_head,
            ],
            axis=-1,
        )

        x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
        x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
        BCdt = self.bcdt_proj(x)
        x = x.reshape(bsize, length, self.num_heads, -1)
        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)

        A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
        dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
        B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps)
        C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps)

        # (bsize, length, num_heads, 1)
        dt = self.dt_proj(dt)[..., None]

        out, ssm_state = ssd_chunk_scan_combined(
            x,
            dt.reshape(bsize, length, -1),
            A,
            B,
            C,
            D=self.D,
            z=z,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            ssm_state=ssm_state,
        )

        if cache is not None:
            cache[0] = conv_state
            cache[1] = ssm_state
        y = self.out_proj(out.reshape(bsize, length, -1))

        return y

class Attention(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        head_dim = config.hidden_size_per_head
        self.max_position_embeddings = config.max_position_embeddings
        self.scale = head_dim**-0.5

        self.q_num_heads = config.num_attention_heads
        self.qk_dim = self.v_dim = head_dim
        self.k_num_heads = self.v_num_heads = config.num_key_value_heads
        assert self.q_num_heads % self.k_num_heads == 0
        self.n_group = self.q_num_heads // self.k_num_heads

        self.q_proj_dim = self.q_num_heads * self.qk_dim
        self.k_proj_dim = self.k_num_heads * self.qk_dim
        self.v_proj_dim = self.k_num_heads * self.v_dim
        self.qkv_proj = nn.Linear(
            self.hidden_size,
            self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
        )

        self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
        self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))

        self.rope = nn.RoPE(self.qk_dim)

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        B, T, _ = hidden_states.shape

        qkv = self.qkv_proj(hidden_states)
        q, k, v = mx.split(
            qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
        )
        q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]

        if cache is not None:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        output = mx.fast.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=self.scale,
            mask=mask,
        )
        output = output.transpose(0, 2, 1, 3).reshape(
            B, T, self.q_num_heads * self.v_dim
        )
        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        h = self.gate_up_proj(x)
        hs = mx.split(h, 2, axis=-1)
        return self.down_proj(nn.silu(hs[0]) * hs[1])

class PlamoDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, is_mamba: bool) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.is_mamba = is_mamba
        self.mixer: nn.Module
        if is_mamba:
            self.mixer = Mamba(config)
        else:
            self.mixer = Attention(config)
        self.mlp = MLP(config)
        self.pre_mixer_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
        )
        self.post_mixer_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
        )
        self.pre_mlp_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
        )
        self.post_mlp_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
        )

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        residual = hidden_states
        hidden_states = self.pre_mixer_norm(hidden_states)

        hidden_states_sa = self.mixer(
            hidden_states=hidden_states,
            mask=mask,
            cache=cache,
        )

        hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
        hidden_states = residual + hidden_states_sa

        residual = hidden_states
        hidden_states = self.pre_mlp_norm(hidden_states)

        # Fully Connected
        hidden_states_mlp = self.mlp(hidden_states)

        # Residual
        hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
        return residual + hidden_states_mlp

def is_mamba(config: ModelArgs, i: int) -> bool:
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers

    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)

class PlamoDecoder(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()

        self.layers = [
            PlamoDecoderLayer(config, is_mamba=is_mamba(config, i))
            for i in range(config.num_hidden_layers)
        ]

    def __call__(self, x: mx.array, mask: mx.array, cache):
        for i, decoder_layer in enumerate(self.layers):
            x = decoder_layer(
                x,
                mask=mask,
                cache=cache[i],
            )
        return x

class PlamoModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = PlamoDecoder(config)  # type: ignore
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        batch_size, seq_length = inputs.shape

        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)

        if cache is None:
            cache = [None] * len(self.layers.layers)

        # decoder layers
        out = self.layers(
            h,
            mask,
            cache,
        )

        return self.norm(out)

class Model(nn.Module):

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = PlamoModel(config)

        self.vocab_size = config.vocab_size

        if not config.tie_word_embeddings:
            self.lm_head: nn.Module = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False
            )

    def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self):
        # TODO: use RotatingKVCache if not full_attn
        # full_attn = self.layer_idx in self.config.full_attention_idx
        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]

    def __call__(
        self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
    ) -> mx.array:
        outputs = self.model(
            inputs=inputs,
            mask=None,
            cache=cache,
        )
        if self.config.tie_word_embeddings:
            logits = self.model.embed_tokens.as_linear(outputs)
        else:
            logits = self.lm_head(outputs)

        return logits

    @property
    def layers(self):
        return self.model.layers.layers
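Editor's sketch (not part of the diff): exercising the new module directly with a small configuration, mirroring the `test_plamo2` unit test added later in this commit:

```python
# Sketch: instantiate the new plamo2 model with a small config and run a
# forward pass, checking the same output shape as the unit test below.
import mlx.core as mx
from mlx_lm.models import plamo2

args = plamo2.ModelArgs(
    model_type="plamo2",
    hidden_size=1024,
    num_hidden_layers=4,
    intermediate_size=2048,
    num_attention_heads=8,
    rms_norm_eps=1e-5,
    vocab_size=10_000,
)
model = plamo2.Model(args)

inputs = mx.array([[0, 1]])        # (batch=1, sequence_length=2)
cache = model.make_cache()         # MambaCache / KVCache per layer
logits = model(inputs, cache=cache)
assert logits.shape == (1, 2, args.vocab_size)
```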
@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
-        )

     def to_lora(layer):
         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):

@@ -161,7 +156,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")

-    for l in model.layers[-min(num_layers, 0) :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))

@@ -192,6 +192,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
                     "tokenizer.model",
                     "*.tiktoken",
                     "*.txt",
+                    "*.jsonl",
                 ],
             )
         )
@@ -382,8 +383,8 @@ def speculative_generate_step(
           and a bool indicating if the token was generated by the draft model
     """

-    y = prompt
-    tokens = None
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None

     # Create the KV cache for generation
     if prompt_cache is None:
@@ -404,17 +405,38 @@ def speculative_generate_step(
             kv_bits=kv_bits,
         )

+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        y = sampler(logprobs)
+        return y, logprobs
+
     def _step(model, cache, y, n_predict=1):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]

             quantize_cache_fn(cache)

-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            logprobs = logprobs.squeeze(0)
-            y = sampler(logprobs)
-            return y, logprobs
+            if logits_processors:
+                nonlocal prev_tokens
+                out_y, out_logprobs = [], []
+                if n_predict > 1:
+                    y = y[: -(n_predict - 1)]
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    )
+                    y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
+                    out_y.append(y)
+                    out_logprobs.append(logprobs)
+                return mx.concatenate(out_y, axis=0), mx.concatenate(
+                    out_logprobs, axis=0
+                )
+            else:
+                return _process_and_sample(None, logits.squeeze(0))

     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -451,8 +473,9 @@ def speculative_generate_step(
         while True:
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
             y = mx.concatenate([y, draft_tokens])

             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
@@ -485,6 +508,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )

+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)

@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)

-        if model_type != "mamba":
+        if model_type not in ("mamba", "plamo2"):
             mask = create_causal_mask(inputs.shape[1], 0).astype(t)
             outputs = model(inputs, mask=mask)
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_plamo2(self):
+        from mlx_lm.models import plamo2
+
+        args = plamo2.ModelArgs(
+            model_type="plamo2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_stablelm(self):
         from mlx_lm.models import stablelm

@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten

@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):


 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))


 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
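Editor's note (not part of the diff): the rewritten `to_samples` packs the token stream into non-overlapping windows of `context_size + 1` tokens rather than every shifted window. A small worked example of the arithmetic:

```python
# Illustration of the new windowing, using a toy token stream.
import numpy as np

context_size = 3
window_size = context_size + 1           # each window holds inputs plus target
dataset = np.arange(10)                  # 10 tokens: 0..9

samples = dataset.size // window_size    # 10 // 4 == 2 full windows
dataset = dataset[: samples * window_size]
windows = dataset.reshape(samples, -1)
# windows == [[0, 1, 2, 3],
#             [4, 5, 6, 7]]; tokens 8 and 9 are dropped.
# loss_fn later splits each window into inputs windows[..., :-1]
# and next-token targets windows[..., 1:].
```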
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)

     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )

     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])

     state = [model.state, optimizer.state]

     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss

     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
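Editor's note (not part of the diff): the new `eval_fn` sums the per-token cross-entropy (`reduction="sum"`) and divides by `inputs.size - inputs.shape[0]`, which is the number of predicted tokens, since each row of `context_size + 1` tokens yields `context_size` targets. A tiny check of that arithmetic:

```python
# Sketch (not repository code): why eval_fn divides by inputs.size - inputs.shape[0].
samples, window_size = 2, 4            # 2 rows of context_size + 1 == 4 tokens
inputs_size = samples * window_size    # 8 tokens total
num_targets = inputs_size - samples    # each row predicts window_size - 1 == 3 tokens
assert num_targets == 6                # summed loss / 6 == mean per-token loss
```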