From 08a8dd2507fd78c7159058acd6d7da10b318ef09 Mon Sep 17 00:00:00 2001
From: Shunta Saito
Date: Fri, 28 Feb 2025 01:17:35 +0900
Subject: [PATCH] Fix plamo2 model to use rms_norm and enable sliding window
 attention

---
 llms/mlx_lm/models/plamo2.py | 41 +++++++++++++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index 1d8215dd..e9410e0f 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -53,6 +53,16 @@ class RMSNorm(nn.Module):
         )
 
 
+def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
@@ -344,6 +354,15 @@ class Mamba(nn.Module):
         return y
 
 
+def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
+    max_len = max(q_len, kv_len)
+    mask = mx.tril(
+        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
+        k=window_size,
+    )
+    return mask[-q_len:, -kv_len:]
+
+
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -392,8 +411,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
 
-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
 
         if cache is not None:
             q = self.rope(q, offset=cache.offset)
@@ -403,6 +422,23 @@ class Attention(nn.Module):
             q = self.rope(q)
             k = self.rope(k)
 
+        if mask is not None:
+            if mask.dtype == bool:
+                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
+            if len(mask.shape) == 2:
+                mask = mask[None, None]
+            assert len(mask.shape) == 4
+
+            m_swa = swa_mask(
+                q.shape[2],
+                k.shape[2],
+                self.config.attention_window_size,
+            )
+            # `generate` function creates attention mask that does not consider sliding window
+            m_swa = m_swa[None, None]
+            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
+            mask = mx.where(m_swa, mask, float("-inf"))
+
         output = mx.fast.scaled_dot_product_attention(
             q,
             k,
@@ -556,7 +592,6 @@ class PlamoModel(nn.Module):
 
 
 class Model(nn.Module):
-
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
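
The sliding-window behaviour introduced above can be sanity-checked in isolation. The snippet below is a minimal sketch that copies `swa_mask` from the patch and combines its boolean band with an additive causal mask, mirroring the `mx.where(m_swa, mask, float("-inf"))` step in `Attention`; the sequence length of 6 and window size of 2 are illustrative values, not taken from the model config.

import mlx.core as mx


def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
    # Boolean band: query position i may attend to key position j only if
    # |i - j| <= window_size.
    max_len = max(q_len, kv_len)
    mask = mx.tril(
        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),
        k=window_size,
    )
    return mask[-q_len:, -kv_len:]


q_len = kv_len = 6
window = 2

# Additive causal mask of the kind `generate` builds: 0.0 where attention is
# allowed, -inf for future positions.
causal = mx.where(
    mx.tril(mx.ones((q_len, kv_len), dtype=mx.bool_)), 0.0, float("-inf")
)

# Restrict the causal mask to the window, as the patch does before calling
# scaled_dot_product_attention.
banded = mx.where(swa_mask(q_len, kv_len, window), causal, float("-inf"))
print(banded)  # each row keeps at most `window + 1` finite entries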