diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index da28aa2a..773a6839 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -498,11 +498,11 @@ def selective_state_update_ref(
     dt = nn.softplus(dt) if dt_softplus else dt
     dA = mx.exp(mx.expand_dims(dt, axis=-1) * A)  # (batch, nheads, dim, dstate)
     B = mx.reshape(
-        mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
+        mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
         (batch, nheads, dstate),
     )  # (batch, nheads, dstate)
     C = mx.reshape(
-        mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
+        mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
         (batch, nheads, dstate),
     )  # (batch, nheads, dstate)
     dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2)  # (batch, nheads, dim, dstate)
@@ -1070,8 +1070,8 @@ class Attention(nn.Module):

        # expand shared kv
        assert self.k_num_heads == self.v_num_heads
-        key_states = mx.tile(key_states, (1, self.n_group, 1, 1))
-        value_states = mx.tile(value_states, (1, self.n_group, 1, 1))
+        key_states = mx.repeat(key_states, self.n_group, 1)
+        value_states = mx.repeat(value_states, self.n_group, 1)

        full_attn = self.layer_idx in self.config.full_attention_idx
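Note on the swap (not part of the patch): for the `selective_state_update_ref` hunks the expanded axis has size 1 after `mx.expand_dims`, so `mx.tile` and `mx.repeat` produce identical results there. In the attention hunk the axis holds multiple KV heads, and the two calls order the expanded heads differently. A minimal standalone sketch of that difference, assuming MLX is imported as `mlx.core`:

```python
# Sketch only: contrast mx.tile and mx.repeat along a non-singleton axis,
# the axis used for the shared-KV expansion in the attention hunk above.
import mlx.core as mx

# Two "KV heads" along axis 1, expanded by a hypothetical group factor of 3.
kv = mx.array([[[0], [1]]])               # shape (1, 2, 1)

tiled = mx.tile(kv, (1, 3, 1))            # head order: 0, 1, 0, 1, 0, 1
repeated = mx.repeat(kv, 3, axis=1)       # head order: 0, 0, 0, 1, 1, 1

print(tiled.squeeze().tolist())           # [0, 1, 0, 1, 0, 1]
print(repeated.squeeze().tolist())        # [0, 0, 0, 1, 1, 1]
```

Which ordering is correct for the attention path depends on how the query heads are grouped against the shared KV heads; the sketch only illustrates that the two expansions are not interchangeable once `k_num_heads > 1`.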