From 3fdf85e79d4c97ed614991dd3277f93af4c2cc0a Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sun, 3 Mar 2024 15:12:03 +0100
Subject: [PATCH] Starcoder2: Update config and change GQA to use repeat (#520)

* update config

* change gqa to use repeat instead of concatenate

* contribution
---
 ACKNOWLEDGMENTS.md               |  1 +
 llms/mlx_lm/models/starcoder2.py | 23 ++++++-----------------
 2 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 3bca9bd3..f9528f38 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,3 +13,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
 - Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
+- Prince Canuma: Helped add support for `Starcoder2` models.
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 27a53af9..c0e32412 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -18,20 +18,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int = None
     max_position_embeddings: int = 16384
-    norm_eps: float = None
-    rms_norm_eps: float = 1e-5
+    norm_epsilon: float = 1e-5
     norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
 
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.norm_eps is None:
-            self.norm_eps = self.rms_norm_eps
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -68,12 +60,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
@@ -111,9 +100,9 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
         self.post_attention_layernorm = LayerNorm(
-            args.hidden_size, eps=args.rms_norm_eps
+            args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
 
@@ -141,7 +130,7 @@ class Starcoder2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
 
     def __call__(
         self,
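
A quick way to see why the GQA change is behavior-preserving: with keys of shape (B, n_kv_heads, L, head_dim), the old concatenate-and-reshape helper places the copies of each kv head next to each other along the head axis, which is exactly what `mx.repeat(..., axis=1)` produces. The toy script below is a minimal sketch of that equivalence; it is not part of the patch, and the tensor sizes (B, n_kv_heads, L, head_dim, repeats) are made-up illustration values.

    import mlx.core as mx

    # Illustrative sizes only (assumed for this sketch).
    B, n_kv_heads, L, head_dim = 1, 2, 4, 8
    repeats = 4                      # n_heads // n_kv_heads
    n_heads = n_kv_heads * repeats

    keys = mx.random.normal((B, n_kv_heads, L, head_dim))

    # Old approach: stack `repeats` copies on a new axis, then fold them into the head axis.
    old = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2)
    old = old.reshape([B, n_heads, L, -1])

    # New approach from the patch: repeat each kv head in place along the head axis.
    new = mx.repeat(keys, repeats, axis=1)

    # Both layouts match element for element.
    assert mx.array_equal(old, new).item()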