Fix plamo2 model to use rms_norm and enable sliding window attention

This commit is contained in:
Shunta Saito 2025-02-28 01:17:35 +09:00
parent 00a7379070
commit 08a8dd2507

View File

@ -53,6 +53,16 @@ class RMSNorm(nn.Module):
) )
def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(mx.float32)
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + eps)
hidden_states = hidden_states.astype(input_dtype)
return hidden_states
def get_initial_dt_bias(num_heads: int) -> mx.array: def get_initial_dt_bias(num_heads: int) -> mx.array:
dt_min = 0.001 dt_min = 0.001
dt_max = 0.1 dt_max = 0.1
@ -344,6 +354,15 @@ class Mamba(nn.Module):
return y return y
def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
max_len = max(q_len, kv_len)
mask = mx.tril(
mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size), # type: ignore
k=window_size,
)
return mask[-q_len:, -kv_len:]
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, config: ModelArgs) -> None: def __init__(self, config: ModelArgs) -> None:
super().__init__() super().__init__()
@ -392,8 +411,8 @@ class Attention(nn.Module):
k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None] q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None] k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
if cache is not None: if cache is not None:
q = self.rope(q, offset=cache.offset) q = self.rope(q, offset=cache.offset)
@ -403,6 +422,23 @@ class Attention(nn.Module):
q = self.rope(q) q = self.rope(q)
k = self.rope(k) k = self.rope(k)
if mask is not None:
if mask.dtype == bool:
mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
if len(mask.shape) == 2:
mask = mask[None, None]
assert len(mask.shape) == 4
m_swa = swa_mask(
q.shape[2],
k.shape[2],
self.config.attention_window_size,
)
# `generate` function creates attention mask that does not consider sliding window
m_swa = m_swa[None, None]
mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
mask = mx.where(m_swa, mask, float("-inf"))
output = mx.fast.scaled_dot_product_attention( output = mx.fast.scaled_dot_product_attention(
q, q,
k, k,
@ -556,7 +592,6 @@ class PlamoModel(nn.Module):
class Model(nn.Module): class Model(nn.Module):
def __init__(self, config: ModelArgs) -> None: def __init__(self, config: ModelArgs) -> None:
super().__init__() super().__init__()
self.config = config self.config = config