Stable lm 2 (#666)

* stable lm 2

* test and lora

* version bump

* merge stable models
This commit is contained in:
Awni Hannun 2024-04-08 14:18:55 -07:00 committed by GitHub
parent 1e2f7f50b6
commit c68aa3c7c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 19 deletions

View File

@ -16,11 +16,27 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
partial_rotary_factor: float
intermediate_size: int
layer_norm_eps: float
rope_theta: float
use_qkv_bias: bool
partial_rotary_factor: float
layer_norm_eps: float
use_parallel_residual: bool = False
qk_layernorm: bool = False
class LayerNormPerHead(nn.Module):
def __init__(self, head_dim, num_heads, eps):
super().__init__()
self.norms = [
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
]
self.eps = eps
def __call__(self, x):
w = mx.stack([n.weight for n in self.norms])
return w * mx.fast.layer_norm(x, None, None, self.eps)
class Attention(nn.Module):
@ -63,22 +79,31 @@ class Attention(nn.Module):
base=self.rope_theta,
)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = LayerNormPerHead(
self.head_dim, self.num_heads, eps=config.layer_norm_eps
)
self.k_layernorm = LayerNormPerHead(
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
queries = queries.reshape(B, L, self.num_heads, -1)
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
if self.qk_layernorm:
queries = self.q_layernorm(queries)
keys = self.k_layernorm(keys)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
0, 2, 1, 3
)
values = values.reshape(
B, L, self.num_key_value_heads, self.head_dim
).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
@ -120,14 +145,23 @@ class DecoderLayer(nn.Module):
self.self_attn = Attention(config=config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
config.hidden_size,
eps=config.layer_norm_eps,
)
self.use_parallel_residual = config.use_parallel_residual
if not self.use_parallel_residual:
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
config.hidden_size,
eps=config.layer_norm_eps,
)
def __call__(self, x, mask, cache):
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = self.input_layernorm(x)
r, cache = self.self_attn(h, mask, cache)
if self.use_parallel_residual:
out = x + r + self.mlp(h)
else:
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.7.0"
__version__ = "0.8.0"

View File

@ -242,6 +242,25 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
# StableLM 2
args = stablelm.ModelArgs(
model_type="stablelm",
vocab_size=10000,
hidden_size=512,
num_attention_heads=8,
num_hidden_layers=4,
num_key_value_heads=2,
partial_rotary_factor=0.25,
intermediate_size=1024,
layer_norm_eps=1e-5,
rope_theta=10000,
use_qkv_bias=True,
)
model = stablelm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_starcoder2(self):
from mlx_lm.models import starcoder2