Repository: https://github.com/ml-explore/mlx-examples.git (mirror)

Remove unnecessary code and add a test for plamo2

Parent: d4e688edc3
Commit: 31225f4960
@@ -443,6 +443,7 @@ def get_initial_A(num_heads: int) -> mx.array:
     return mx.log(A)


+# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
 def selective_state_update_ref(
     state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
 ) -> tuple[mx.array, mx.array]:
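The comment added above links to the upstream reference kernel. As a reading aid, here is a minimal MLX sketch of the per-step SSM recurrence that reference implements; the function name and the simplified shapes are illustrative, not the actual plamo2 code:

import mlx.core as mx

def selective_state_update_sketch(state, x, dt, A, B, C, D, dt_bias):
    # Shapes (assumed for this sketch): state (batch, dim, dstate);
    # x, dt (batch, dim); A (dim, dstate); B, C (batch, dstate); D, dt_bias (dim,)
    dt = mx.logaddexp(dt + dt_bias, 0.0)    # softplus, i.e. the dt_softplus=True path
    dA = mx.exp(dt[..., None] * A)          # discretized decay, (batch, dim, dstate)
    dBx = dt[..., None] * B[:, None, :] * x[..., None]
    state = state * dA + dBx                # one recurrence step
    out = (state * C[:, None, :]).sum(axis=-1) + D * x
    return out, state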
@@ -681,18 +682,13 @@ def _causal_conv1d(
     return x, None


-def causal_conv1d_update(
-    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
-) -> tuple[mx.array, mx.array]:
+# From: https://github.com/Dao-AILab/causal-conv1d/blob/82867a9d2e6907cc0f637ac6aff318f696838548/causal_conv1d/causal_conv1d_interface.py#L206
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None) -> tuple[mx.array, mx.array]:
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
-    cache_seqlens: (batch,), dtype int32.
-        If not None, the conv_state is treated as a circular buffer.
-        The conv_state will be updated by copying x to the conv_state starting at the index
-        @cache_seqlens % state_len before performing the convolution.

     out: (batch, dim) or (batch, dim, seqlen)
     """
@@ -707,21 +703,8 @@ def causal_conv1d_update(
     state_len = conv_state.shape[-1]
     assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
-    if cache_seqlens is None:
-        x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
-        conv_state = x_new[:, :, -state_len:]
-    else:
-        width_idx = mx.expand_dims(mx.arange(-(width - 1), 0, dtype=mx.int64), axis=0) + mx.expand_dims(
-            cache_seqlens, axis=1
-        )
-        width_idx = mx.expand_dims(mx.remainder(width_idx, state_len), axis=1)
-        width_idx = mx.broadcast_to(width_idx, (width_idx.shape[0], dim, width_idx.shape[2]))
-        x_new = mx.concatenate([conv_state.gather(2, width_idx), x], axis=-1)
-        x_new = x_new.astype(weight.dtype)
-        copy_idx = mx.expand_dims(mx.arange(seqlen, dtype=mx.int64), axis=0) + mx.expand_dims(cache_seqlens, axis=1)
-        copy_idx = mx.expand_dims(mx.remainder(copy_idx, state_len), axis=1)
-        copy_idx = mx.broadcast_to(copy_idx, (copy_idx.shape[0], dim, copy_idx.shape[2]))
-        conv_state.scatter_(2, copy_idx, x)
+    x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)  # (batch, dim, state_len + seqlen)
+    conv_state = x_new[:, :, -state_len:]
     assert bias is None
     # x_new: (N, C, L) -> (N, L, C)
     out = mx.conv1d(
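With the circular-buffer path gone, the update reduces to: append the new tokens to the rolling state, keep the last state_len columns, and run a depthwise convolution. A minimal sketch under the docstring's shapes; the helper name and the depthwise weight reshaping are choices made here for illustration, not code from the file:

import mlx.core as mx

def conv_update_sketch(x, conv_state, weight):
    # x: (batch, dim, seqlen); conv_state: (batch, dim, state_len); weight: (dim, width)
    batch, dim, seqlen = x.shape
    state_len = conv_state.shape[-1]
    x_new = mx.concatenate([conv_state, x], axis=-1).astype(weight.dtype)
    conv_state = x_new[:, :, -state_len:]        # new rolling state
    # mx.conv1d takes (N, L, C) input and (C_out, K, C_in/groups) weight;
    # groups=dim makes it a depthwise convolution.
    out = mx.conv1d(
        x_new.transpose(0, 2, 1),
        mx.expand_dims(weight, axis=-1),         # (dim, width, 1)
        groups=dim,
    )
    out = out.transpose(0, 2, 1)[:, :, -seqlen:] # keep only the new steps
    return out, conv_state

Because the docstring guarantees state_len >= width - 1, the convolution over the concatenated buffer always yields at least seqlen valid output positions.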
@@ -1599,6 +1582,7 @@ class Model(PlamoPreTrainedModel):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__(config)
         self.config = config
+        self.model_type = config.model_type
         self.model = PlamoModel(config)

         self.vocab_size = config.vocab_size
@@ -1607,8 +1591,6 @@ class Model(PlamoPreTrainedModel):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)

-        self._prefill = True
-
         # Initialize weights and apply final processing
         # self.post_init()

@@ -1642,12 +1624,11 @@ class Model(PlamoPreTrainedModel):
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
         model_inputs = self.prepare_inputs_for_generation(
             input_ids=inputs,
             attention_mask=mask,
             past_key_values=cache,
             use_cache=self.config.use_cache,
         )
-        if self._prefill:
-            model_inputs["input_ids"] = inputs
-            self._prefill = False
+        model_inputs["input_ids"] = inputs
         output = self.forward(**model_inputs)
         if not isinstance(output, CausalLMOutputWithPast):
             raise ValueError(
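The dropped _prefill flag was dead weight: the branch assigned input_ids, and the next line assigned the same key again unconditionally. A tiny self-contained illustration of that redundancy (plain Python with hypothetical stand-ins, not the model code):

inputs = [7, 8, 9]
model_inputs = {"input_ids": None}

prefill = True
if prefill:                           # the old conditional path
    model_inputs["input_ids"] = inputs
    prefill = False
model_inputs["input_ids"] = inputs    # ran unconditionally right after

assert model_inputs["input_ids"] == inputs  # identical outcome without the flag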
@@ -1770,7 +1751,7 @@ class Model(PlamoPreTrainedModel):
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.astype(mx.int64).cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids = mx.where(attention_mask == 0, 1, position_ids)
         if past_key_values:
             position_ids = position_ids[:, -1].unsqueeze(-1)

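masked_fill_ is an in-place torch idiom that MLX arrays do not support, so the fix swaps in a functional mx.where. A quick check of the equivalence on a toy mask (the values here are made up):

import mlx.core as mx

attention_mask = mx.array([[1, 1, 1, 0]])
position_ids = attention_mask.astype(mx.int64).cumsum(axis=-1) - 1  # [[0, 1, 2, 2]]
# functional replacement for position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = mx.where(attention_mask == 0, 1, position_ids)       # [[0, 1, 2, 1]]
print(position_ids)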
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)

-        if model_type != "mamba":
+        if model_type not in ("mamba", "plamo2"):
             mask = create_causal_mask(inputs.shape[1], 0).astype(t)
             outputs = model(inputs, mask=mask)
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_plamo2(self):
+        from mlx_lm.models import plamo2
+
+        args = plamo2.ModelArgs(
+            model_type="plamo2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=8,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+        model = plamo2.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_stablelm(self):
         from mlx_lm.models import stablelm

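The new test leans entirely on the shared model_test_runner, which is not shown here; judging from the TestModels hunk above, it forwards a (1, 2) token batch and checks the output shape and dtype. A standalone approximation of what it exercises for plamo2 (the input token values are arbitrary):

import mlx.core as mx
from mlx_lm.models import plamo2

args = plamo2.ModelArgs(
    model_type="plamo2",
    hidden_size=1024,
    num_hidden_layers=4,
    intermediate_size=2048,
    num_attention_heads=8,
    rms_norm_eps=1e-5,
    vocab_size=10_000,
)
model = plamo2.Model(args)

inputs = mx.array([[0, 1]])            # (batch=1, seqlen=2)
outputs = model(inputs)
assert outputs.shape == (1, 2, args.vocab_size)
# Like mamba, plamo2 is excluded from the explicit causal-mask check,
# since its SSM layers do not consume an attention mask.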