Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Add plamo-2-1b model (#1283)
* Add pfnet/plamo-2-1b
* Fix cache.py to support non-top level layers
* Use mlx's BaseModelArgs
* Fix model
* Use sanitize()
* Remove unnecessary changes
* Add plamo2.py
* Apply formatter
* Fix some part
* Allow a cache obj defined externally
* Fix channel first weights to channel last for right use of MLX's conv1d
* Remove unused code part
* Give all inputs when it's the first time call of model
* Fix import
* Include .jsonl files to download from Huggingface hub
* Fix reference to layers
* Remove unnecessary code and add a test for plamo2
* Do not pass mask to prepare_inputs_for_generation
* Fix to use repeat instead of tile
* Add state property to PlamoCache
* Add __iter__ and __next__ methods to PlamoCache
* cleanup
* cleanup
* fix

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent 09b641aaa7
commit c37e26a1a3
601 llms/mlx_lm/models/plamo2.py (new file)
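A minimal usage sketch for the new model, assuming the converted weights are published as pfnet/plamo-2-1b (the repo id named in the commit message) and using mlx-lm's load/generate helpers; the prompt and token budget are illustrative only:

    from mlx_lm import load, generate

    # Download/load the converted weights and tokenizer from the Hugging Face hub.
    model, tokenizer = load("pfnet/plamo-2-1b")
    print(generate(model, tokenizer, prompt="Hello", max_tokens=32))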
@@ -0,0 +1,601 @@
# Copyright © 2025 Apple Inc.

import math
from dataclasses import dataclass
from typing import Any, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, create_attention_mask

from .cache import KVCache, MambaCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "plamo2"
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = True
    num_attention_heads: int = 32
    num_key_value_heads: int = 4
    hidden_size_per_head: int = 128
    max_position_embeddings: int = 2048
    attention_window_size: int = 2048
    full_attention_idx: Optional[list[int]] = None
    mamba_d_state: int = 64
    mamba_d_conv: int = 4
    mamba_num_heads: int = 64
    mamba_step: int = 2
    mamba_chunk_size: int = 256
    mamba_enabled: bool = True
    intermediate_size: int = 13312
    vocab_size: int = 32000
    max_position_embeddings: int = 10 * 1024 * 1024


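# RMSNorm variant used throughout PLaMo-2: the stored weight is an offset-relative
# scale, so the effective gain passed to mx.fast.rms_norm is (weight + offset).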
class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        offset: float = 1.0,
    ) -> None:
        super().__init__()
        self.weight = mx.zeros(hidden_size)
        self.variance_epsilon = eps
        self.offset = offset

    def __call__(self, hidden_states: mx.array) -> mx.array:
        return mx.fast.rms_norm(
            hidden_states, self.weight + self.offset, self.variance_epsilon
        )


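# Mamba-style dt bias init: sample dt log-uniformly in [dt_min, dt_max] and store
# its inverse softplus, so that softplus(dt_bias) recovers a value in that range.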
def get_initial_dt_bias(num_heads: int) -> mx.array:
    dt_min = 0.001
    dt_max = 0.1
    dt = mx.exp(
        mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
        + math.log(dt_min)
    )
    dt = mx.clip(dt, a_min=1e-4, a_max=None)
    inv_dt = dt + mx.log(-mx.expm1(-dt))
    return inv_dt


def get_initial_A(num_heads: int) -> mx.array:
    A = mx.arange(1, num_heads + 1, dtype=mx.float32)
    return mx.log(A)


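# Reference single-step selective state update:
#   state <- exp(dt * A) * state + dt * B * x
#   out = C · state (+ D * x), optionally gated by silu(z).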
# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
) -> tuple[mx.array, mx.array]:
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.ndim > 3
    if state.ndim == 3:
        state = mx.expand_dims(state, 1)
    if x.ndim == 2:
        x = mx.expand_dims(x, 1)
    if dt.ndim == 2:
        dt = mx.expand_dims(dt, 1)
    if A.ndim == 2:
        A = mx.expand_dims(A, 0)
    if B.ndim == 2:
        B = mx.expand_dims(B, 1)
    if C.ndim == 2:
        C = mx.expand_dims(C, 1)
    if D is not None and D.ndim == 1:
        D = mx.expand_dims(D, 0)
    if z is not None and z.ndim == 2:
        z = mx.expand_dims(z, 1)
    if dt_bias is not None and dt_bias.ndim == 1:
        dt_bias = mx.expand_dims(dt_bias, 0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = nn.softplus(dt) if dt_softplus else dt
    dA = mx.exp(mx.expand_dims(dt, axis=-1) * A)  # (batch, nheads, dim, dstate)
    B = mx.reshape(
        mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
        (batch, nheads, dstate),
    )  # (batch, nheads, dstate)
    C = mx.reshape(
        mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
        (batch, nheads, dstate),
    )  # (batch, nheads, dstate)
    dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
        B, axis=-2
    )  # (batch, nheads, dim, dstate)
    state = state * dA + dB * mx.expand_dims(x, axis=-1)  # (batch, dim, dstate)
    out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
    if D is not None:
        out += (x * D).astype(out.dtype)
    out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out, state


def ssd_update_state(
    ssm_state: mx.array,
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    dtype = x.dtype

    hidden_size_per_head = x.shape[-1]
    d_state = B.shape[-1]
    A = mx.broadcast_to(
        A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
    ).astype(mx.float32)
    dt = mx.broadcast_to(
        dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
    )
    dt_bias = mx.broadcast_to(
        dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
    )
    D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
    out, ssm_state = selective_state_update_ref(
        ssm_state,
        x.astype(dtype),
        dt.astype(dtype),
        A.astype(mx.float32),
        B.astype(dtype),
        C.astype(dtype),
        D.astype(mx.float32),
        z.astype(dtype),
        dt_bias.astype(mx.float32),
        dt_softplus=dt_softplus,
    )
    return out[:, None], ssm_state


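# Sequential reference scan: applies ssd_update_state one token at a time,
# threading ssm_state through the sequence and returning it for the cache.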
def ssd_chunk_scan_combined(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
    ssm_state: mx.array,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    length = x.shape[1]
    ys = []
    for i in range(length):
        y, ssm_state = ssd_update_state(
            ssm_state,
            x[:, i],
            dt[:, i],
            A,
            B[:, i],
            C[:, i],
            D if D.ndim == 1 else D[:, i],
            z=z[:, i],
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
        )
        ys.append(y)
    return mx.concatenate(ys, axis=1), ssm_state


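# Depthwise causal conv over the time axis: prepend the cached context, convolve,
# and keep the last state_len positions as the new conv state.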
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
    batch, seqlen, dim = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-2]
    x = mx.concatenate([conv_state, x], axis=-2)
    conv_state = x[:, -state_len:]
    out = mx.conv1d(
        x,
        weight,
        padding=0,
        groups=dim,
    )[:, -seqlen:]
    return nn.silu(out), conv_state


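# Mamba (SSM) mixer: in_proj splits the hidden state into a gate z and an input x;
# x goes through a depthwise causal conv, then B/C/dt projections, then the selective
# state update, and the gated result is mapped back with out_proj.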
class Mamba(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.d_state = config.mamba_d_state
        self.d_conv = config.mamba_d_conv
        self.chunk_size = config.mamba_chunk_size
        self.num_heads = config.mamba_num_heads
        self.hidden_size_per_head = config.hidden_size_per_head

        self.intermediate_size = self.num_heads * self.hidden_size_per_head

        self.in_proj = nn.Linear(
            self.hidden_size, 2 * self.intermediate_size, bias=False
        )
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=False,
            kernel_size=self.d_conv,
            groups=self.intermediate_size,
            padding=0,
        )
        self.dt_dim = max(64, self.hidden_size // 16)
        self.bcdt_proj = nn.Linear(
            self.intermediate_size,
            self.dt_dim + 2 * self.d_state,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)

        self.dt_bias = get_initial_dt_bias(self.num_heads)
        self.A_log = get_initial_A(self.num_heads)
        self.D = mx.ones(self.num_heads, dtype=mx.float32)

        self.dt_norm_weight = mx.ones(self.dt_dim)
        self.B_norm_weight = mx.ones(self.d_state)
        self.C_norm_weight = mx.ones(self.d_state)

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        bsize, length, _ = hidden_states.shape

        if cache is not None and cache[0] is not None:
            conv_state = cache[0]
            ssm_state = cache[1]
        else:
            conv_state = mx.zeros(
                (bsize, self.d_conv - 1, self.intermediate_size),
                dtype=hidden_states.dtype,
            )
            ssm_state = mx.zeros(
                (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
                dtype=mx.float32,
            )

        zx = self.in_proj(hidden_states)
        zx = zx.reshape(bsize, length, self.num_heads, -1)
        # z: (bsize, length, num_heads, hidden_size_per_head)
        # x: (bsize, length, num_heads, hidden_size_per_head)
        z, x = mx.split(
            zx,
            [
                self.hidden_size_per_head,
            ],
            axis=-1,
        )

        x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
        x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
        BCdt = self.bcdt_proj(x)
        x = x.reshape(bsize, length, self.num_heads, -1)
        B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)

        A = -mx.exp(self.A_log.astype(mx.float32))  # (num_heads,)
        dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
        B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps)
        C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps)

        # (bsize, length, num_heads, 1)
        dt = self.dt_proj(dt)[..., None]

        out, ssm_state = ssd_chunk_scan_combined(
            x,
            dt.reshape(bsize, length, -1),
            A,
            B,
            C,
            D=self.D,
            z=z,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            ssm_state=ssm_state,
        )

        if cache is not None:
            cache[0] = conv_state
            cache[1] = ssm_state
        y = self.out_proj(out.reshape(bsize, length, -1))

        return y


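# Grouped-query attention: per-head layer norm on q and k with learned scales
# (QK norm), RoPE positions, then mx.fast.scaled_dot_product_attention.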
class Attention(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        head_dim = config.hidden_size_per_head
        self.max_position_embeddings = config.max_position_embeddings
        self.scale = head_dim**-0.5

        self.q_num_heads = config.num_attention_heads
        self.qk_dim = self.v_dim = head_dim
        self.k_num_heads = self.v_num_heads = config.num_key_value_heads
        assert self.q_num_heads % self.k_num_heads == 0
        self.n_group = self.q_num_heads // self.k_num_heads

        self.q_proj_dim = self.q_num_heads * self.qk_dim
        self.k_proj_dim = self.k_num_heads * self.qk_dim
        self.v_proj_dim = self.k_num_heads * self.v_dim
        self.qkv_proj = nn.Linear(
            self.hidden_size,
            self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
        )

        self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
        self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))

        self.rope = nn.RoPE(self.qk_dim)

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        B, T, _ = hidden_states.shape

        qkv = self.qkv_proj(hidden_states)
        q, k, v = mx.split(
            qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
        )
        q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]

        if cache is not None:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        output = mx.fast.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=self.scale,
            mask=mask,
        )
        output = output.transpose(0, 2, 1, 3).reshape(
            B, T, self.q_num_heads * self.v_dim
        )
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        h = self.gate_up_proj(x)
        hs = mx.split(h, 2, axis=-1)
        return self.down_proj(nn.silu(hs[0]) * hs[1])


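# Decoder block: RMSNorm is applied both before and after the mixer and the MLP;
# note the post-norms use smaller offsets (1/5 and 5**-1.5).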
class PlamoDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, is_mamba: bool) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.is_mamba = is_mamba
        self.mixer: nn.Module
        if is_mamba:
            self.mixer = Mamba(config)
        else:
            self.mixer = Attention(config)
        self.mlp = MLP(config)
        self.pre_mixer_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
        )
        self.post_mixer_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
        )
        self.pre_mlp_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0
        )
        self.post_mlp_norm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
        )

    def __call__(
        self,
        hidden_states: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        residual = hidden_states
        hidden_states = self.pre_mixer_norm(hidden_states)

        hidden_states_sa = self.mixer(
            hidden_states=hidden_states,
            mask=mask,
            cache=cache,
        )

        hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
        hidden_states = residual + hidden_states_sa

        residual = hidden_states
        hidden_states = self.pre_mlp_norm(hidden_states)

        # Fully Connected
        hidden_states_mlp = self.mlp(hidden_states)

        # Residual
        hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
        return residual + hidden_states_mlp


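# Layer schedule: with the default mamba_step=2, even-indexed layers are Mamba and
# odd-indexed layers are attention; very shallow models (num_hidden_layers <=
# mamba_step // 2) use attention only in the last layer.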
def is_mamba(config: ModelArgs, i: int) -> bool:
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers

    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)


class PlamoDecoder(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()

        self.layers = [
            PlamoDecoderLayer(config, is_mamba=is_mamba(config, i))
            for i in range(config.num_hidden_layers)
        ]

    def __call__(self, x: mx.array, mask: mx.array, cache):
        for i, decoder_layer in enumerate(self.layers):
            x = decoder_layer(
                x,
                mask=mask,
                cache=cache[i],
            )
        return x


class PlamoModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = PlamoDecoder(config)  # type: ignore
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        batch_size, seq_length = inputs.shape

        h = self.embed_tokens(inputs)

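        # Only the attention mixers consume the mask; it is built from cache[1], which
        # is the first attention layer's KVCache under the default alternating schedule.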
        if mask is None:
            mask = create_attention_mask(h, [cache[1]] if cache is not None else None)

        if cache is None:
            cache = [None] * len(self.layers.layers)

        # decoder layers
        out = self.layers(
            h,
            mask,
            cache,
        )

        return self.norm(out)


class Model(nn.Module):

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = PlamoModel(config)

        self.vocab_size = config.vocab_size

        if not config.tie_word_embeddings:
            self.lm_head: nn.Module = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False
            )

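    # Checkpoints store Conv1d weights channel-first (out_channels, in_channels/groups,
    # kernel); MLX's conv1d expects channel-last, so move the kernel axis into place.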
    def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self):
        # TODO: use RotatingKVCache if not full_attn
        # full_attn = self.layer_idx in self.config.full_attention_idx
        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]

    def __call__(
        self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
    ) -> mx.array:
        outputs = self.model(
            inputs=inputs,
            mask=None,
            cache=cache,
        )
        if self.config.tie_word_embeddings:
            logits = self.model.embed_tokens.as_linear(outputs)
        else:
            logits = self.lm_head(outputs)

        return logits

    @property
    def layers(self):
        return self.model.layers.layers
@@ -192,6 +192,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
                    "tokenizer.model",
                    "*.tiktoken",
                    "*.txt",
                    "*.jsonl",
                ],
            )
        )
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
        self.assertEqual(outputs.shape, (1, 2, vocab_size))
        self.assertEqual(outputs.dtype, t)

-        if model_type != "mamba":
+        if model_type not in ("mamba", "plamo2"):
            mask = create_causal_mask(inputs.shape[1], 0).astype(t)
            outputs = model(inputs, mask=mask)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo2(self):
        from mlx_lm.models import plamo2

        args = plamo2.ModelArgs(
            model_type="plamo2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm