Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Fix plamo2 model to use rms_norm (#1308)
* Fix plamo2 model to use rms_norm and enable sliding window attention
* Fix missing variable
* Remove the sliding window attention implementation, because it should be done with RotatingKVCache
* Remove unused imports
This commit is contained in:
parent 845cd8c01e
commit 269faa5fa4
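Background for the change, as a hedged aside (not part of the diff): `mx.fast.layer_norm(x, None, None, eps)` subtracts the per-feature mean before normalizing, while RMS norm only rescales by the root mean square, so the two disagree whenever the activations have a nonzero mean. The commit message also notes that sliding-window attention should come from a RotatingKVCache rather than a reimplementation inside the model. A minimal sketch of the layer-norm/RMS-norm difference:

# Hedged sketch (illustrative only, not code from this commit).
import mlx.core as mx

eps = 1e-6
x = mx.random.normal((2, 4, 8)) + 1.0   # deliberately give the features a nonzero mean

# layer norm: subtract the mean, then normalize by the variance
ln = mx.fast.layer_norm(x, None, None, eps)

# RMS norm: no mean subtraction, only rescale by the root mean square
rms = x * mx.rsqrt(mx.power(x, 2).mean(-1, keepdims=True) + eps)

print(mx.abs(ln - rms).max())           # clearly nonzero: the two norms disagree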
@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
     max_position_embeddings: int = 10 * 1024 * 1024


 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )


+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
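A hedged sanity check for the helper added above (assumption: `mx.fast.rms_norm` with a unit weight vector is the fused equivalent of the pure-MLX `_rms_norm`, up to float32 accumulation):

# Hedged sketch, reusing the _rms_norm defined in the hunk above.
import mlx.core as mx

x = mx.random.normal((2, 8, 128)).astype(mx.float16)

a = _rms_norm(x, 1e-6)                                           # pure-MLX version from this commit
b = mx.fast.rms_norm(x, mx.ones((128,), dtype=x.dtype), 1e-6)    # fused kernel with unit weight

print(mx.abs(a.astype(mx.float32) - b.astype(mx.float32)).max())  # expected to be ~0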
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(


 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
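A hedged aside on the updated unpacking (toy example, not the repository's code): only `seqlen` and `dim` are used here, and the state update just keeps a rolling window of the last `state_len` timesteps so the causal conv can be advanced one decode step at a time.

# Hedged sketch of the rolling conv-state update with toy shapes.
import mlx.core as mx

state_len, dim = 3, 2
conv_state = mx.zeros((1, state_len, dim))          # cached context from previous steps
x = mx.random.normal((1, 1, dim))                   # one new decode step

x_full = mx.concatenate([conv_state, x], axis=-2)   # (1, state_len + 1, dim)
conv_state = x_full[:, -state_len:]                 # slide the window forward by one
print(conv_state.shape)                             # (1, 3, 2)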
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]

         if cache is not None:
             q = self.rope(q, offset=cache.offset)
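For context, a hedged illustration of the QK normalization pattern in this hunk (the name q_weight comes from the diff; its (num_heads, head_dim) shape is an assumption): each head's features are RMS-normalized over the head dimension and then rescaled by a learned per-head vector, broadcast over the sequence axis.

# Hedged sketch with toy shapes; the RMS norm is inlined to stay self-contained.
import mlx.core as mx

B, H, T, D = 1, 4, 16, 32                     # batch, heads, sequence, head_dim
q = mx.random.normal((B, H, T, D))
q_weight = mx.ones((H, D))                    # assumed learned scale, one vector per head

q_normed = q * mx.rsqrt(mx.power(q, 2).mean(-1, keepdims=True) + 1e-6)
q_scaled = q_normed * q_weight[:, None]       # (H, 1, D) broadcasts over the T axis
print(q_scaled.shape)                         # (1, 4, 16, 32)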
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):


 class Model(nn.Module):

     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):

         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
             )

     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
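On the lm_head change (a hedged reading, matching the "Fix missing variable" bullet): `vocab_size` was a bare local name inside `__init__`, so the fix reads the value stored on the instance instead. A minimal sketch of the assumed pattern; the attribute wiring beyond what the diff shows is illustrative, not the repository's exact code:

# Hedged sketch; not the repository's exact constructor.
import mlx.nn as nn

class TinyModel(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.vocab_size = vocab_size                 # stored once as an attribute
        if not tie_word_embeddings:
            # use the attribute, not an undefined local, when building the head
            self.lm_head = nn.Linear(hidden_size, self.vocab_size, bias=False)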