diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index f9a0c40c..055ccb80 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -443,6 +443,7 @@ def get_initial_A(num_heads: int) -> mx.array:
     return mx.log(A)
 
 
+# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
 def selective_state_update_ref(
     state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
 ) -> tuple[mx.array, mx.array]:
@@ -681,18 +682,13 @@ def _causal_conv1d(
     return x, None
 
 
-def causal_conv1d_update(
-    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
-) -> tuple[mx.array, mx.array]:
+# From: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None) -> tuple[mx.array, mx.array]:
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
-    cache_seqlens: (batch,), dtype int32.
-        If not None, the conv_state is treated as a circular buffer.
-        The conv_state will be updated by copying x to the conv_state starting at the index
-        @cache_seqlens % state_len before performing the convolution.
 
     out: (batch, dim) or (batch, dim, seqlen)
     """
@@ -707,21 +703,8 @@ def causal_conv1d_update(
     state_len = conv_state.shape[-1]
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
-    if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
-        conv_state = x_new[:, :, -state_len:]
-    else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + mx.expand_dims(
-            cache_seqlens, axis=1
-        )
-        width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1)
-        width_idx = mx.broadcast_to(width_idx, (width_idx.shape[0], dim, width_idx.shape[2]))
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1)
-        x_new = x_new.astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + mx.expand_dims(cache_seqlens, axis=1)
-        copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1)
-        copy_idx = mx.broadcast_to(copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]))
-        conv_state.scatter_(2, copy_idx, x)
+    x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+    conv_state = x_new[:, :, -state_len:]
     assert bias is None
     # x_new: (N, C, L) -> (N, L, C)
     out = mx.conv1d(
@@ -1599,6 +1582,7 @@ class Model(PlamoPreTrainedModel):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__(config)
         self.config = config
+        self.model_type = config.model_type
         self.model = PlamoModel(config)
 
         self.vocab_size = config.vocab_size
@@ -1607,8 +1591,6 @@ class Model(PlamoPreTrainedModel):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
 
-        self._prefill = True
-
         # Initialize weights and apply final processing
         # self.post_init()
 
@@ -1642,12 +1624,11 @@ class Model(PlamoPreTrainedModel):
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
         model_inputs = self.prepare_inputs_for_generation(
             input_ids=inputs,
+            attention_mask=mask,
             past_key_values=cache,
             use_cache=self.config.use_cache,
         )
-        if self._prefill:
-            model_inputs["input_ids"] = inputs
-            self._prefill = False
+        model_inputs["input_ids"] = inputs
         output = self.forward(**model_inputs)
         if not isinstance(output, CausalLMOutputWithPast):
             raise ValueError(
@@ -1770,7 +1751,7 @@ class Model(PlamoPreTrainedModel):
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = mx.where(attention_mask == 0, 1, position_ids)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
 
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index d8cf6820..0c0fc601 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)
 
-            if model_type != "mamba":
+            if model_type not in ("mamba", "plamo2"):
                 mask = create_causal_mask(inputs.shape[1], 0).astype(t)
                 outputs = model(inputs, mask=mask)
                 self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_plamo2(self):
+        from mlx_lm.models import plamo2
+
+        args = plamo2.ModelArgs(
+            model_type="plamo2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_stablelm(self):
         from mlx_lm.models import stablelm
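
Note (not part of the diff): a minimal sketch of the two behavioral changes above, with illustrative shapes and values that are assumptions, not taken from the patch. The simplified causal_conv1d_update drops the circular-buffer path and simply keeps the trailing state_len columns of the concatenated buffer, and the functional mx.where replaces PyTorch's in-place masked_fill_, which has no equivalent on immutable MLX arrays.

    import mlx.core as mx

    # Sliding-window conv-state update: append the new tokens, keep the tail.
    # Shapes are made up for illustration.
    batch, dim, state_len, seqlen = 1, 2, 3, 1
    conv_state = mx.zeros((batch, dim, state_len))
    x = mx.ones((batch, dim, seqlen))
    x_new = mx.concatenate([conv_state, x], axis=-1)  # (batch, dim, state_len + seqlen)
    conv_state = x_new[:, :, -state_len:]             # last state_len columns become the new state

    # Functional stand-in for position_ids.masked_fill_(attention_mask == 0, 1):
    # mx.where selects 1 wherever the mask is 0, leaving other positions unchanged.
    attention_mask = mx.array([[1, 1, 0]])
    position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1
    position_ids = mx.where(attention_mask == 0, 1, position_ids)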