mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

Stable lm 2 (#666)

* stable lm 2
* test and lora
* version bump
* merge stable models

This commit is contained in:
parent 1e2f7f50b6
commit c68aa3c7c3

@@ -16,11 +16,27 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    partial_rotary_factor: float
     intermediate_size: int
-    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool
+    partial_rotary_factor: float
+    layer_norm_eps: float
+    use_parallel_residual: bool = False
+    qk_layernorm: bool = False
+
+
+class LayerNormPerHead(nn.Module):
+
+    def __init__(self, head_dim, num_heads, eps):
+        super().__init__()
+        self.norms = [
+            nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
+        ]
+        self.eps = eps
+
+    def __call__(self, x):
+        w = mx.stack([n.weight for n in self.norms])
+        return w * mx.fast.layer_norm(x, None, None, self.eps)
+
+
 class Attention(nn.Module):
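
For reference, a minimal sketch (not part of the commit) of what the new LayerNormPerHead computes: each head keeps its own learned scale, while the normalization itself is shared and bias-free, so stacking the per-head weights over a single fused mx.fast.layer_norm matches applying the per-head norms one slice at a time. The toy shapes below are assumptions.

import mlx.core as mx
import mlx.nn as nn

B, L, num_heads, head_dim, eps = 1, 4, 8, 64, 1e-5
x = mx.random.normal((B, L, num_heads, head_dim))

norms = [nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)]

# Reference: apply each head's LayerNorm to its own slice of the head axis.
reference = mx.stack([norms[h](x[:, :, h, :]) for h in range(num_heads)], axis=2)

# Vectorized form from the diff: one weight-free layer_norm, then the stacked
# per-head weights broadcast over (B, L, num_heads, head_dim).
w = mx.stack([n.weight for n in norms])
fused = w * mx.fast.layer_norm(x, None, None, eps)

print(mx.allclose(reference, fused, atol=1e-5))  # expected: True
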
@@ -63,22 +79,31 @@ class Attention(nn.Module):
             base=self.rope_theta,
         )
 
+        self.qk_layernorm = config.qk_layernorm
+        if self.qk_layernorm:
+            self.q_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_heads, eps=config.layer_norm_eps
+            )
+            self.k_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
+            )
+
     def __call__(self, x, mask=None, cache=None):
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Extract some shapes
         B, L, D = queries.shape
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
+        queries = queries.reshape(B, L, self.num_heads, -1)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1)
+        if self.qk_layernorm:
+            queries = self.q_layernorm(queries)
+            keys = self.k_layernorm(keys)
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
             0, 2, 1, 3
         )
-        keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        values = values.reshape(
-            B, L, self.num_key_value_heads, self.head_dim
-        ).transpose(0, 2, 1, 3)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
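
A hedged shape walk-through (not in the diff) of the reordered __call__ path: reshape(..., -1) infers head_dim, the optional per-head q/k norm acts while heads sit on the second-to-last axis, and only then do queries, keys and values move to (B, heads, L, head_dim). The toy sizes and the random stand-ins for the projection outputs are assumptions.

import mlx.core as mx

B, L = 2, 5
num_heads, num_key_value_heads, head_dim = 8, 2, 16

# Stand-ins for the q/k/v projection outputs.
queries = mx.random.normal((B, L, num_heads * head_dim))
keys = mx.random.normal((B, L, num_key_value_heads * head_dim))
values = mx.random.normal((B, L, num_key_value_heads * head_dim))

# Split heads first; a per-head q/k layer norm would be applied here, on the
# (B, L, heads, head_dim) layout, before the transpose.
queries = queries.reshape(B, L, num_heads, -1)
keys = keys.reshape(B, L, num_key_value_heads, -1)

queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_key_value_heads, -1).transpose(0, 2, 1, 3)

print(queries.shape, keys.shape, values.shape)
# (2, 8, 5, 16) (2, 2, 5, 16) (2, 2, 5, 16)
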
@@ -120,14 +145,23 @@ class DecoderLayer(nn.Module):
         self.self_attn = Attention(config=config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
         self.input_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
+            config.hidden_size,
+            eps=config.layer_norm_eps,
         )
-        self.post_attention_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
+        self.use_parallel_residual = config.use_parallel_residual
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(
+                config.hidden_size,
+                eps=config.layer_norm_eps,
+            )
 
     def __call__(self, x, mask, cache):
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
+        h = self.input_layernorm(x)
+        r, cache = self.self_attn(h, mask, cache)
+
+        if self.use_parallel_residual:
+            out = x + r + self.mlp(h)
+        else:
+            h = x + r
+            r = self.mlp(self.post_attention_layernorm(h))
+            out = h + r
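
A short sketch (not from the commit) contrasting the two residual layouts the updated DecoderLayer can take; attn, mlp and the norms below are stand-in callables, not the real submodules.

import mlx.core as mx

def block(x, attn, mlp, input_ln, post_ln, use_parallel_residual):
    h = input_ln(x)
    r = attn(h)
    if use_parallel_residual:
        # Parallel residual: attention and MLP both read the same normalized
        # input and are summed onto the raw residual stream.
        return x + r + mlp(h)
    # Sequential pre-norm residual: the MLP sees the post-attention stream.
    h = x + r
    return h + mlp(post_ln(h))

# Toy check with identity submodules: the two layouts reduce to simple sums.
ident = lambda t: t
x = mx.ones((1, 2, 4))
print(block(x, ident, ident, ident, ident, True))   # 3 * x
print(block(x, ident, ident, ident, ident, False))  # 4 * x
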
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.7.0"
+__version__ = "0.8.0"

@@ -242,6 +242,25 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+        # StableLM 2
+        args = stablelm.ModelArgs(
+            model_type="stablelm",
+            vocab_size=10000,
+            hidden_size=512,
+            num_attention_heads=8,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            partial_rotary_factor=0.25,
+            intermediate_size=1024,
+            layer_norm_eps=1e-5,
+            rope_theta=10000,
+            use_qkv_bias=True,
+        )
+        model = stablelm.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_starcoder2(self):
         from mlx_lm.models import starcoder2
 
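
A hedged usage sketch: with this commit, a StableLM 2 checkpoint should load through the regular mlx_lm entry points. The Hugging Face repo id below is an assumption and is not referenced anywhere in the diff.

from mlx_lm import load, generate

# Repo id is an assumption; any StableLM 2 checkpoint with a compatible
# config should exercise the new stablelm model code.
model, tokenizer = load("stabilityai/stablelm-2-zephyr-1_6b")
print(generate(model, tokenizer, prompt="MLX is", max_tokens=50))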