Fix plamo2 model to use rms_norm and enable sliding window attention
This commit is contained in:
parent 00a7379070
commit 08a8dd2507
@@ -53,6 +53,16 @@ class RMSNorm(nn.Module):
         )


+def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
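For reference, the new `_rms_norm` helper normalizes by the root mean square of the last axis in float32 and casts back to the input dtype. The snippet below is an illustrative standalone sanity check, not part of the commit; the comparison against `mx.fast.rms_norm` with a unit weight and the tolerance are assumptions.

import mlx.core as mx

def _rms_norm(hidden_states: mx.array, eps: float, offset: float = 1.0) -> mx.array:
    # Reproduced from the hunk above so the check runs standalone.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.astype(mx.float32)
    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
    hidden_states = hidden_states * mx.rsqrt(variance + eps)
    return hidden_states.astype(input_dtype)

# Assumed check: with a weight of ones, the fused kernel should agree with
# the hand-written version up to floating-point tolerance.
x = mx.random.normal((2, 4, 8))
ref = mx.fast.rms_norm(x, mx.ones((8,)), 1e-6)
print(mx.allclose(_rms_norm(x, 1e-6), ref, atol=1e-5))  # expected: True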
@@ -344,6 +354,15 @@ class Mamba(nn.Module):
         return y


+def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
+    max_len = max(q_len, kv_len)
+    mask = mx.tril(
+        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
+        k=window_size,
+    )
+    return mask[-q_len:, -kv_len:]
+
+
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
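`swa_mask` builds a boolean band matrix that is True only where the key position lies within `window_size` of the query position, then crops it to the current query and key lengths. A small illustrative run is sketched below; the concrete sizes are made up and not taken from the commit.

import mlx.core as mx

def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
    # Copied from the hunk above: band of True values of half-width
    # `window_size` around the diagonal, cropped to (q_len, kv_len).
    max_len = max(q_len, kv_len)
    mask = mx.tril(
        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),
        k=window_size,
    )
    return mask[-q_len:, -kv_len:]

# A single decoded token attending to a 5-entry KV cache with window 2 can
# only see the last three cached positions.
print(swa_mask(1, 5, 2))  # roughly [[False, False, True, True, True]]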
@@ -392,8 +411,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]

         if cache is not None:
             q = self.rope(q, offset=cache.offset)
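The change above replaces `mx.fast.layer_norm` (which subtracts the per-head mean even when weight and bias are `None`) with the mean-free RMS normalization, as the commit title describes. The snippet below only illustrates why the two are not interchangeable; the shapes and the +1.0 shift are arbitrary choices for the example.

import mlx.core as mx

# Illustration (not from the commit): layer_norm and RMS norm agree only on
# zero-mean inputs, so a query/key tensor with a nonzero mean is normalized
# differently by the two.
x = mx.random.normal((1, 2, 4, 8)) + 1.0  # deliberately nonzero mean

ln = mx.fast.layer_norm(x, None, None, 1e-6)
rms = x * mx.rsqrt(mx.power(x, 2).mean(-1, keepdims=True) + 1e-6)

print(mx.abs(ln - rms).max())  # clearly nonzero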
@@ -403,6 +422,23 @@ class Attention(nn.Module):
             q = self.rope(q)
             k = self.rope(k)

+        if mask is not None:
+            if mask.dtype == bool:
+                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
+            if len(mask.shape) == 2:
+                mask = mask[None, None]
+            assert len(mask.shape) == 4
+
+            m_swa = swa_mask(
+                q.shape[2],
+                k.shape[2],
+                self.config.attention_window_size,
+            )
+            # `generate` function creates attention mask that does not consider sliding window
+            m_swa = m_swa[None, None]
+            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
+            mask = mx.where(m_swa, mask, float("-inf"))
+
         output = mx.fast.scaled_dot_product_attention(
             q,
             k,
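The new masking block converts a boolean attention mask to additive form, promotes it to 4D, crops it to the live query/key lengths, and then forces everything outside the sliding window to -inf using `swa_mask`. Below is a hedged walk-through with made-up sizes (1 query token, a 6-entry cache, window 3); it inlines the band construction from the earlier hunk rather than calling the model code itself.

import mlx.core as mx

q_len, kv_len, window = 1, 6, 3  # assumed sizes for illustration

# Mask as `generate` might supply it: additive, all positions visible.
mask = mx.zeros((1, 1, q_len, kv_len))

# Sliding-window band, built the same way as swa_mask in the hunk above.
m_swa = mx.tril(
    mx.triu(mx.ones((kv_len, kv_len), dtype=mx.bool_), k=-window),
    k=window,
)[-q_len:, -kv_len:][None, None]

# Outside the window the score becomes -inf, inside it is left unchanged.
mask = mask[:, :, -q_len:, -kv_len:]
mask = mx.where(m_swa, mask, float("-inf"))

print(mask)  # [[[[-inf, -inf, 0, 0, 0, 0]]]]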
@@ -556,7 +592,6 @@ class PlamoModel(nn.Module):


 class Model(nn.Module):
-
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config