Fix plamo2 model to use rms_norm (#1308)

* Fix plamo2 model to use rms_norm and enable sliding window attention

* Fix missing variable

* Remove sliding window attention implementation, since it should be handled via RotatingKVCache

* Remove unused imports
Shunta Saito 2025-03-03 23:12:02 +09:00 committed by GitHub
parent 845cd8c01e
commit 269faa5fa4
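
The core of the change: `mx.fast.layer_norm(x, None, None, eps)` subtracts the per-row mean before rescaling, whereas PLaMo-2's query/key normalization is an RMS norm, which only divides by the root-mean-square. A minimal sketch of the difference (illustration only, not part of the commit):

import mlx.core as mx

def rms_norm_ref(x: mx.array, eps: float = 1e-6) -> mx.array:
    # x / sqrt(mean(x^2) + eps), computed in float32 as in the new _rms_norm helper
    x32 = x.astype(mx.float32)
    variance = mx.power(x32, 2).mean(-1, keepdims=True)
    return (x32 * mx.rsqrt(variance + eps)).astype(x.dtype)

x = mx.array([[1.0, 2.0, 3.0, 4.0]])          # non-zero mean on purpose
ln = mx.fast.layer_norm(x, None, None, 1e-6)  # centers to zero mean, then rescales
rn = rms_norm_ref(x)                          # only rescales by the RMS
print(ln)  # ~[-1.34, -0.45, 0.45, 1.34]
print(rn)  # ~[ 0.37,  0.73, 1.10, 1.46]

The two agree only on zero-mean inputs, which is why the discrepancy is easy to miss but still diverges from the intended RMS normalization.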

@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
-    max_position_embeddings: int = 10 * 1024 * 1024


 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )


+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
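
A quick sanity check for the helper above (illustration only, assuming `_rms_norm` is in scope): each row of the output should have unit RMS, and, if your MLX build exposes `mx.fast.rms_norm`, the fused kernel with a unit weight should agree up to floating-point error.

import mlx.core as mx

x = mx.random.normal((2, 8, 64))
y = _rms_norm(x, 1e-6)

rms = mx.sqrt(mx.power(y, 2).mean(-1))
print(mx.allclose(rms, mx.ones_like(rms), atol=1e-3))   # True: rows are unit-RMS

y_fast = mx.fast.rms_norm(x, mx.ones(x.shape[-1]), 1e-6)
print(mx.allclose(y, y_fast, atol=1e-4))                # True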
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(
 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
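
For context, the cache handling visible in this hunk follows a simple rolling-window pattern: incoming tokens are appended to the cached state and only the most recent `state_len` steps are kept for the next call. A toy sketch with assumed shapes (not the model's exact code):

import mlx.core as mx

batch, state_len, dim = 1, 3, 2
conv_state = mx.zeros((batch, state_len, dim))      # cache from earlier steps
x = mx.arange(1, 2 * dim + 1, dtype=mx.float32).reshape(batch, 2, dim)  # two new tokens

window = mx.concatenate([conv_state, x], axis=-2)   # (batch, state_len + 2, dim)
new_conv_state = window[:, -state_len:]             # keep only the last state_len steps
print(new_conv_state.shape)                         # (1, 3, 2)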
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]

         if cache is not None:
             q = self.rope(q, offset=cache.offset)
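
The two changed lines apply the new helper per attention head and then rescale with a learned weight. A shape sketch, assuming `q_weight` has shape `(num_heads, head_dim)` (not verified against the full file) and that `_rms_norm` from above is in scope:

import mlx.core as mx

B, num_heads, T, head_dim = 1, 4, 5, 8
q = mx.random.normal((B, num_heads, T, head_dim))
q_weight = mx.ones((num_heads, head_dim))           # stands in for the learned parameter

q_normed = _rms_norm(q, 1e-6) * q_weight[:, None]   # (num_heads, 1, head_dim) broadcasts over B and T
print(q_normed.shape)                               # (1, 4, 5, 8)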
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):
 class Model(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
             )

     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
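
The last change fixes a `NameError`: there is no local `vocab_size` in `__init__`, so the untied head must read the instance attribute. A hedged sketch of that constructor path, with hypothetical class and argument names:

import mlx.nn as nn

class HeadOnly(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.vocab_size = vocab_size
        if not tie_word_embeddings:
            # mirrors the fixed line: use self.vocab_size, not an undefined local
            self.lm_head = nn.Linear(hidden_size, self.vocab_size, bias=False)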