From c68aa3c7c381b39a563a9102c5925f9d3b1523a8 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 8 Apr 2024 14:18:55 -0700
Subject: [PATCH] Stable lm 2 (#666)

* stable lm 2

* test and lora

* version bump

* merge stable models
---
 llms/mlx_lm/models/stablelm.py | 70 +++++++++++++++++++++++++---------
 llms/mlx_lm/version.py         |  2 +-
 llms/tests/test_models.py      | 19 +++++++++
 3 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index f685c76d..47af5295 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -16,11 +16,27 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_hidden_layers: int
     num_key_value_heads: int
-    partial_rotary_factor: float
     intermediate_size: int
-    layer_norm_eps: float
     rope_theta: float
     use_qkv_bias: bool
+    partial_rotary_factor: float
+    layer_norm_eps: float
+    use_parallel_residual: bool = False
+    qk_layernorm: bool = False
+
+
+class LayerNormPerHead(nn.Module):
+
+    def __init__(self, head_dim, num_heads, eps):
+        super().__init__()
+        self.norms = [
+            nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
+        ]
+        self.eps = eps
+
+    def __call__(self, x):
+        w = mx.stack([n.weight for n in self.norms])
+        return w * mx.fast.layer_norm(x, None, None, self.eps)
 
 
 class Attention(nn.Module):
@@ -63,22 +79,31 @@ class Attention(nn.Module):
             base=self.rope_theta,
         )
 
+        self.qk_layernorm = config.qk_layernorm
+        if self.qk_layernorm:
+            self.q_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_heads, eps=config.layer_norm_eps
+            )
+            self.k_layernorm = LayerNormPerHead(
+                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
+            )
+
     def __call__(self, x, mask=None, cache=None):
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Extract some shapes
         B, L, D = queries.shape
 
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
+        queries = queries.reshape(B, L, self.num_heads, -1)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1)
+        if self.qk_layernorm:
+            queries = self.q_layernorm(queries)
+            keys = self.k_layernorm(keys)
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
             0, 2, 1, 3
         )
-        keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        values = values.reshape(
-            B, L, self.num_key_value_heads, self.head_dim
-        ).transpose(0, 2, 1, 3)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
@@ -120,17 +145,26 @@ class DecoderLayer(nn.Module):
         self.self_attn = Attention(config=config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
         self.input_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
-        self.post_attention_layernorm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
+            config.hidden_size,
+            eps=config.layer_norm_eps,
         )
+        self.use_parallel_residual = config.use_parallel_residual
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(
+                config.hidden_size,
+                eps=config.layer_norm_eps,
+            )
 
     def __call__(self, x, mask, cache):
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
+        h = self.input_layernorm(x)
+        r, cache = self.self_attn(h, mask, cache)
+
+        if self.use_parallel_residual:
+            out = x + r + self.mlp(h)
+        else:
+            h = x + r
+            r = self.mlp(self.post_attention_layernorm(h))
+            out = h + r
         return out, cache
 
 
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 3d13212d..e339bd95 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.7.0"
+__version__ = "0.8.0"
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index effeab53..57fab58d 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -242,6 +242,25 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+        # StableLM 2
+        args = stablelm.ModelArgs(
+            model_type="stablelm",
+            vocab_size=10000,
+            hidden_size=512,
+            num_attention_heads=8,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            partial_rotary_factor=0.25,
+            intermediate_size=1024,
+            layer_norm_eps=1e-5,
+            rope_theta=10000,
+            use_qkv_bias=True,
+        )
+        model = stablelm.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
     def test_starcoder2(self):
         from mlx_lm.models import starcoder2
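
Note: the snippet below is a minimal sketch and not part of the patch. It mirrors the StableLM test added above in llms/tests/test_models.py, but flips on the two new options this change introduces (use_parallel_residual and qk_layernorm); the concrete argument values are illustrative only.

from mlx_lm.models import stablelm

# Illustrative configuration exercising the new StableLM 2 code paths.
args = stablelm.ModelArgs(
    model_type="stablelm",
    vocab_size=10000,
    hidden_size=512,
    num_attention_heads=8,
    num_hidden_layers=4,
    num_key_value_heads=2,
    partial_rotary_factor=0.25,
    intermediate_size=1024,
    layer_norm_eps=1e-5,
    rope_theta=10000,
    use_qkv_bias=True,
    use_parallel_residual=True,  # parallel attention/MLP residual branch in DecoderLayer
    qk_layernorm=True,  # per-head LayerNorm on queries and keys via LayerNormPerHead
)
model = stablelm.Model(args)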