Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Fix plamo2 model to use rms_norm (#1308)
* Fix plamo2 model to use rms_norm and enable sliding window attention
* Fix missing variable
* Remove sliding window attention impl. because it should be done by using RotatingKVCache
* Remove unused imports
parent 845cd8c01e
commit 269faa5fa4
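The commit message notes that sliding-window attention belongs in the cache layer rather than in the model code: mlx-lm provides a RotatingKVCache that only retains the most recent key/value positions. A minimal sketch of that idea, assuming the mlx_lm.models.cache module path and the (max_size, keep) constructor arguments match the installed mlx-lm version; the window size and layer count below are illustrative, not from the commit:

import mlx_lm
from mlx_lm.models.cache import RotatingKVCache

num_layers = 32  # hypothetical layer count for illustration

# One rotating cache per transformer layer: it keeps at most `max_size`
# key/value positions (plus `keep` leading "sink" tokens) and overwrites
# older entries, which is exactly the bookkeeping a sliding window needs.
cache = [RotatingKVCache(max_size=4096, keep=4) for _ in range(num_layers)]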
@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
-    max_position_embeddings: int = 10 * 1024 * 1024
 
 
 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )
 
 
+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
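The `_rms_norm` helper added above computes the same quantity as MLX's fused `mx.fast.rms_norm` kernel with a unit weight. A quick sanity check, assuming `mx.fast.rms_norm(x, weight, eps)` is available in the installed MLX version:

import mlx.core as mx

def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
    # Same helper as in the hunk above: normalize by the root-mean-square
    # in float32, then cast back to the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.astype(mx.float32)
    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
    hidden_states = hidden_states * mx.rsqrt(variance + eps)
    return hidden_states.astype(input_dtype)

x = mx.random.normal((2, 8, 64)).astype(mx.float16)
a = _rms_norm(x, 1e-6)
b = mx.fast.rms_norm(x, mx.ones(x.shape[-1], dtype=x.dtype), 1e-6)  # fused kernel, unit weight
print(mx.allclose(a.astype(mx.float32), b.astype(mx.float32), atol=1e-2))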
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(
 
 
 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
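For reference, the lines above implement a rolling convolution state: new timesteps are appended to the cached window and only the last `state_len` steps are kept for the next call. A standalone sketch of just that bookkeeping (the depthwise convolution itself is omitted; the function name and shapes are illustrative):

import mlx.core as mx

def roll_conv_state(conv_state: mx.array, x: mx.array) -> tuple[mx.array, mx.array]:
    # conv_state: (batch, state_len, dim) window of the most recent inputs
    # x:          (batch, seqlen, dim) newly arriving timesteps
    state_len = conv_state.shape[-2]
    x = mx.concatenate([conv_state, x], axis=-2)  # prepend cached context
    conv_state = x[:, -state_len:]                # keep the newest state_len steps
    return conv_state, x

state = mx.zeros((1, 3, 4))
new_state, padded = roll_conv_state(state, mx.ones((1, 2, 4)))
print(new_state.shape, padded.shape)  # (1, 3, 4) (1, 5, 4)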
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
 
-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
 
         if cache is not None:
             q = self.rope(q, offset=cache.offset)
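The motivation for the change above: `mx.fast.layer_norm(x, None, None, eps)` subtracts the per-feature mean before scaling, whereas RMS norm only divides by the root-mean-square, so the two disagree whenever the activations have a non-zero mean. A small illustration with toy values:

import mlx.core as mx

eps = 1e-6
x = mx.array([[1.0, 2.0, 3.0, 4.0]])  # toy activations with a non-zero mean

ln = mx.fast.layer_norm(x, None, None, eps)                        # mean-centered, then scaled
rms = x * mx.rsqrt(mx.power(x, 2).mean(-1, keepdims=True) + eps)   # RMS scaling only

print(ln)   # approx [[-1.342, -0.447, 0.447, 1.342]]
print(rms)  # approx [[ 0.365,  0.730, 1.095, 1.461]]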
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):
 
 
 class Model(nn.Module):
-
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):
 
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
            )
 
     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
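The last hunk fixes the untied-embeddings branch to reference `self.vocab_size`. For context, a brief sketch of the two output-projection options behind `tie_word_embeddings`, assuming `nn.Embedding.as_linear` is available in the installed MLX version; the sizes below are illustrative:

import mlx.core as mx
import mlx.nn as nn

hidden_size, vocab_size = 512, 32000  # illustrative sizes
h = mx.random.normal((1, 4, hidden_size))

embed = nn.Embedding(vocab_size, hidden_size)

# Tied: reuse the input embedding matrix as the output projection.
logits_tied = embed.as_linear(h)

# Untied: a separate lm_head, as constructed in the hunk above.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
logits_untied = lm_head(h)

print(logits_tied.shape, logits_untied.shape)  # (1, 4, 32000) for both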