Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Stable lm 2 (#666)
* stable lm 2
* test and lora
* version bump
* merge stable models
Parent: 1e2f7f50b6
Commit: c68aa3c7c3
@@ -16,11 +16,27 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    partial_rotary_factor: float
     intermediate_size: int
-    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool
+    partial_rotary_factor: float
+    layer_norm_eps: float
+    use_parallel_residual: bool = False
+    qk_layernorm: bool = False
+
+
+class LayerNormPerHead(nn.Module):
+
+    def __init__(self, head_dim, num_heads, eps):
+        super().__init__()
+        self.norms = [
+            nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
+        ]
+        self.eps = eps
+
+    def __call__(self, x):
+        w = mx.stack([n.weight for n in self.norms])
+        return w * mx.fast.layer_norm(x, None, None, self.eps)
 
 
 class Attention(nn.Module):
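The new LayerNormPerHead keeps one weight-only nn.LayerNorm per head but evaluates all heads in a single fused call: it normalizes the last axis once with mx.fast.layer_norm and then scales by the stacked per-head weights. A minimal sketch (not part of the commit; toy shapes, and it assumes the LayerNormPerHead definition from the hunk above is in scope) checking that this matches applying each head's norm to its own slice:

import mlx.core as mx
import mlx.nn as nn  # needed by the LayerNormPerHead definition above

head_dim, num_heads, eps = 4, 3, 1e-5
norm = LayerNormPerHead(head_dim, num_heads, eps)  # class added in the hunk above

x = mx.random.normal((2, 5, num_heads, head_dim))  # (B, L, heads, head_dim)
y = norm(x)

# Reference: run each head's weight-only LayerNorm on its own slice.
ref = mx.stack([norm.norms[h](x[..., h, :]) for h in range(num_heads)], axis=-2)
print(mx.allclose(y, ref, atol=1e-5))  # expected: True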
@@ -63,22 +79,31 @@ class Attention(nn.Module):
             base=self.rope_theta,
         )
 
+        self.qk_layernorm = config.qk_layernorm
+        if self.qk_layernorm:
+            self.q_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_heads, eps=config.layer_norm_eps
+            )
+            self.k_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
+            )
+
     def __call__(self, x, mask=None, cache=None):
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Extract some shapes
         B, L, D = queries.shape
 
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
+        queries = queries.reshape(B, L, self.num_heads, -1)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1)
+        if self.qk_layernorm:
+            queries = self.q_layernorm(queries)
+            keys = self.k_layernorm(keys)
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
             0, 2, 1, 3
         )
-        keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        values = values.reshape(
-            B, L, self.num_key_value_heads, self.head_dim
-        ).transpose(0, 2, 1, 3)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
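The reshape/transpose order changes so that the optional per-head QK norm runs while the head axis is still second-to-last, i.e. on (B, L, heads, head_dim), before the usual transpose to (B, heads, L, head_dim) for attention. A standalone shape-flow sketch with made-up sizes (not the module itself):

import mlx.core as mx

B, L, num_heads, num_kv_heads, head_dim = 1, 6, 8, 2, 64

queries = mx.random.normal((B, L, num_heads * head_dim))
keys = mx.random.normal((B, L, num_kv_heads * head_dim))

queries = queries.reshape(B, L, num_heads, -1)  # (1, 6, 8, 64)
keys = keys.reshape(B, L, num_kv_heads, -1)     # (1, 6, 2, 64)

# qk_layernorm (if enabled) is applied here, while the layout is
# (..., heads, head_dim), which is what LayerNormPerHead expects.

queries = queries.transpose(0, 2, 1, 3)  # (1, 8, 6, 64)
keys = keys.transpose(0, 2, 1, 3)        # (1, 2, 6, 64)
print(queries.shape, keys.shape)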
@@ -120,17 +145,26 @@ class DecoderLayer(nn.Module):
         self.self_attn = Attention(config=config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
         self.input_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
-        self.post_attention_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
+            config.hidden_size,
+            eps=config.layer_norm_eps,
         )
+        self.use_parallel_residual = config.use_parallel_residual
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(
+                config.hidden_size,
+                eps=config.layer_norm_eps,
+            )
 
     def __call__(self, x, mask, cache):
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
+        h = self.input_layernorm(x)
+        r, cache = self.self_attn(h, mask, cache)
+
+        if self.use_parallel_residual:
+            out = x + r + self.mlp(h)
+        else:
+            h = x + r
+            r = self.mlp(self.post_attention_layernorm(h))
+            out = h + r
         return out, cache
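The decoder layer now switches between two residual layouts at runtime: the original sequential path (attention residual, then a second norm before the MLP) and the parallel-residual path used by StableLM 2, where the attention and MLP branches both consume the same normed input and are summed with the raw residual. A toy sketch of the two data flows (stand-in callables, not the real modules):

import mlx.core as mx

def decoder_layer(x, attn, mlp, input_ln, post_ln, use_parallel_residual):
    h = input_ln(x)
    r = attn(h)
    if use_parallel_residual:
        # StableLM 2 style: both branches read the same normed input.
        return x + r + mlp(h)
    # Original path: the MLP reads a second norm of the attention residual stream.
    h = x + r
    return h + mlp(post_ln(h))

x = mx.random.normal((1, 4, 8))
out = decoder_layer(
    x,
    attn=lambda h: h,
    mlp=lambda h: 2 * h,
    input_ln=lambda v: v,
    post_ln=lambda v: v,
    use_parallel_residual=True,
)
print(out.shape)  # (1, 4, 8)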
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.7.0"
+__version__ = "0.8.0"
@@ -242,6 +242,25 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+        # StableLM 2
+        args = stablelm.ModelArgs(
+            model_type="stablelm",
+            vocab_size=10000,
+            hidden_size=512,
+            num_attention_heads=8,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            partial_rotary_factor=0.25,
+            intermediate_size=1024,
+            layer_norm_eps=1e-5,
+            rope_theta=10000,
+            use_qkv_bias=True,
+        )
+        model = stablelm.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_starcoder2(self):
         from mlx_lm.models import starcoder2
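A hedged sketch of building the same toy configuration outside the test harness (assumes mlx and mlx_lm are installed and the stablelm module matches the diff above; the two new flags are enabled here only to exercise the StableLM 2 code paths, they default to False):

from mlx.utils import tree_flatten
from mlx_lm.models import stablelm

args = stablelm.ModelArgs(
    model_type="stablelm",
    vocab_size=10000,
    hidden_size=512,
    num_attention_heads=8,
    num_hidden_layers=4,
    num_key_value_heads=2,
    partial_rotary_factor=0.25,
    intermediate_size=1024,
    layer_norm_eps=1e-5,
    rope_theta=10000,
    use_qkv_bias=True,
    use_parallel_residual=True,  # new field from this commit
    qk_layernorm=True,           # new field from this commit
)
model = stablelm.Model(args)

# Count parameters of the randomly initialized toy model.
n_params = sum(p.size for _, p in tree_flatten(model.parameters()))
print(f"{n_params / 1e6:.1f}M parameters")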