Remove unnecessary code and add a test for plamo2

Shunta Saito 2025-02-23 14:43:14 +09:00
parent d4e688edc3
commit 31225f4960
2 changed files with 27 additions and 29 deletions

@@ -443,6 +443,7 @@ def get_initial_A(num_heads: int) -> mx.array:
     return mx.log(A)


+# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
 def selective_state_update_ref(
     state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
 ) -> tuple[mx.array, mx.array]:
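
For context, `selective_state_update_ref` ports the single-token SSM recurrence from the linked mamba kernel. A minimal sketch of that recurrence under simplified shapes (no `z` gating or `dt_bias` handling; `ssm_step` is an illustrative name, not part of this diff):

import mlx.core as mx
import mlx.nn as nn

def ssm_step(state, x, dt, A, B, C, D):
    # state: (batch, dim, dstate); x, dt: (batch, dim)
    # A: (dim, dstate); B, C: (batch, dstate); D: (dim,)
    dt = nn.softplus(dt)                                # dt_softplus=True path
    dA = mx.exp(dt[..., None] * A)                      # discretized decay
    dBx = dt[..., None] * B[:, None, :] * x[..., None]  # discretized input
    state = state * dA + dBx                            # recurrent state update
    y = (state * C[:, None, :]).sum(axis=-1) + D * x    # readout plus skip term
    return y, state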
@@ -681,18 +682,13 @@ def _causal_conv1d(
     return x, None


-def causal_conv1d_update(
-    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
-) -> tuple[mx.array, mx.array]:
+# From: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None) -> tuple[mx.array, mx.array]:
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
-    cache_seqlens: (batch,), dtype int32.
-        If not None, the conv_state is treated as a circular buffer.
-        The conv_state will be updated by copying x to the conv_state starting at the index
-        @cache_seqlens % state_len before performing the convolution.
     out: (batch, dim) or (batch, dim, seqlen)
     """
@@ -707,21 +703,8 @@ def causal_conv1d_update(
     state_len = conv_state.shape[-1]
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
-    if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
-        conv_state = x_new[:, :, -state_len:]
-    else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + mx.expand_dims(
-            cache_seqlens, axis=1
-        )
-        width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1)
-        width_idx = mx.broadcast_to(width_idx, (width_idx.shape[0], dim, width_idx.shape[2]))
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1)
-        x_new = x_new.astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + mx.expand_dims(cache_seqlens, axis=1)
-        copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1)
-        copy_idx = mx.broadcast_to(copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]))
-        conv_state.scatter_(2, copy_idx, x)
+    x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+    conv_state = x_new[:, :, -state_len:]
     assert bias is None
     # x_new: (N, C, L) -> (N, L, C)
     out = mx.conv1d(
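
Dropping the `cache_seqlens` branch removes the only uses of torch-style `gather`/`scatter_`, which have no in-place MLX counterpart; what remains is a plain sliding window over the last `state_len` inputs. A minimal sketch of the retained path (toy shapes; the exact `mx.conv1d` arguments are an assumption, since the diff truncates that call):

import mlx.core as mx

# Toy single-step update: batch=1, dim=2, width=4, state_len = width - 1 = 3.
x = mx.ones((1, 2, 1))                    # one new token per channel
conv_state = mx.zeros((1, 2, 3))          # last state_len inputs per channel
weight = mx.ones((2, 4))                  # one depthwise filter per channel

x_new = mx.concatenate([conv_state, x], axis=-1)  # (1, 2, 4)
conv_state = x_new[:, :, -3:]                     # window rolls forward one step

out = mx.conv1d(
    x_new.transpose(0, 2, 1),    # mx.conv1d expects channels-last (N, L, C_in)
    mx.expand_dims(weight, -1),  # (C_out, K, C_in / groups) = (2, 4, 1)
    groups=2,                    # depthwise: one filter per channel
)                                # (1, 1, 2): one causal output per new token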
@@ -1599,6 +1582,7 @@ class Model(PlamoPreTrainedModel):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__(config)
         self.config = config
+        self.model_type = config.model_type
         self.model = PlamoModel(config)
         self.vocab_size = config.vocab_size
@@ -1607,8 +1591,6 @@ class Model(PlamoPreTrainedModel):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
-        self._prefill = True
-
         # Initialize weights and apply final processing
         # self.post_init()
@@ -1642,12 +1624,11 @@ class Model(PlamoPreTrainedModel):
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
         model_inputs = self.prepare_inputs_for_generation(
             input_ids=inputs,
-            attention_mask=mask,
             past_key_values=cache,
             use_cache=self.config.use_cache,
         )
-        if self._prefill:
-            model_inputs["input_ids"] = inputs
-            self._prefill = False
+        model_inputs["input_ids"] = inputs
         output = self.forward(**model_inputs)
         if not isinstance(output, CausalLMOutputWithPast):
             raise ValueError(
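
Two simplifications land in this hunk: the `attention_mask=mask` keyword referenced a name that is apparently not defined in this scope, and with the `_prefill` flag gone the full `inputs` array is passed as `input_ids` on every call rather than only on the first one (presumably the cache distinguishes prompt processing from incremental decoding).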
@@ -1770,7 +1751,7 @@ class Model(PlamoPreTrainedModel):
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = mx.where(attention_mask == 0, 1, position_ids)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
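
MLX arrays are immutable, so torch's in-place `masked_fill_` does not exist; `mx.where` is the functional equivalent. A minimal sketch of the replaced line (note the neighbouring `.unsqueeze(-1)` is also a torch-ism that this commit leaves untouched):

import mlx.core as mx

attention_mask = mx.array([[0, 1, 1, 1]])
position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1  # [[-1, 0, 1, 2]]
# Functional replacement for position_ids.masked_fill_(attention_mask == 0, 1):
position_ids = mx.where(attention_mask == 0, 1, position_ids)  # [[1, 0, 1, 2]]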

@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)

-            if model_type != "mamba":
+            if model_type not in ("mamba", "plamo2"):
                 mask = create_causal_mask(inputs.shape[1], 0).astype(t)
                 outputs = model(inputs, mask=mask)
                 self.assertEqual(outputs.shape, (1, 2, vocab_size))
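
plamo2 joins mamba in skipping the explicit-mask check because, as the `__call__` signature above shows, the model accepts only `inputs` and an optional cache, not a `mask` keyword.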
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_plamo2(self):
+        from mlx_lm.models import plamo2
+
+        args = plamo2.ModelArgs(
+            model_type="plamo2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_stablelm(self):
         from mlx_lm.models import stablelm
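
Assuming the suite lives at tests/test_models.py (an inference from the class shown here, not stated in this diff), the new case can be run on its own with `python -m unittest tests.test_models.TestModels.test_plamo2`.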