Fix plamo2 model to use rms_norm (#1308)

* Fix plamo2 model to use rms_norm and enable sliding window attention

* Fix missing variable

* Remove sliding window attention impl. cause it should be done by using RotatingKVCache

* Remove unused imports
This commit is contained in:
Shunta Saito 2025-03-03 23:12:02 +09:00 committed by GitHub
parent 845cd8c01e
commit 269faa5fa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
mamba_enabled: bool = True
intermediate_size: int = 13312
vocab_size: int = 32000
max_position_embeddings: int = 10 * 1024 * 1024
class RMSNorm(nn.Module):
@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
)
def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(mx.float32)
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + eps)
hidden_states = hidden_states.astype(input_dtype)
return hidden_states
def get_initial_dt_bias(num_heads: int) -> mx.array:
dt_min = 0.001
dt_max = 0.1
@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
batch, seqlen, dim = x.shape
width = weight.shape[1]
_, seqlen, dim = x.shape
state_len = conv_state.shape[-2]
x = mx.concatenate([conv_state, x], axis=-2)
conv_state = x[:, -state_len:]
@ -392,8 +400,8 @@ class Attention(nn.Module):
k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
if cache is not None:
q = self.rope(q, offset=cache.offset)
@ -556,7 +564,6 @@ class PlamoModel(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
@ -567,7 +574,7 @@ class Model(nn.Module):
if not config.tie_word_embeddings:
self.lm_head: nn.Module = nn.Linear(
config.hidden_size, vocab_size, bias=False
config.hidden_size, self.vocab_size, bias=False
)
def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: