From 81e2a80026782b74fdf79461fabde90277e536e0 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Sun, 3 Mar 2024 04:39:23 +0100 Subject: [PATCH] Add Starcoder 2 (#502) * Add Starcoder2 model and update utils.py * Refactor model arguments and modules in starcoder2.py * Refactor FeedForward class to MLP in starcoder2.py * Fix typo * pre-commit * Refactor starcoder2.py: Update model arguments and modules * Fix LM head and MLP layers * Rename input layer norm * Update bias in linear layers * Refactor token embeddings in Starcoder2Model * Rename to standard HF attention layer name * Add LayerNorm * Add transposed token embeddings (like in Gemma) * Refactor MLP and TransformerBlock classes * Add tie_word_embeddings option to ModelArgs and update Model implementation * Add conditional check for tying word embeddings in Starcoder2Model * Fix bias in lm_head linear layer * Remove unused LayerNorm in stablelm * Update transformers dependency to use GitHub repository * fix lm head bug, revert transformer req * Update RoPE initialization in Attention class --------- Co-authored-by: Awni Hannun --- llms/README.md | 2 +- llms/mlx_lm/models/stablelm.py | 1 - llms/mlx_lm/models/starcoder2.py | 189 +++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + 4 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 llms/mlx_lm/models/starcoder2.py diff --git a/llms/README.md b/llms/README.md index a15d00c8..27348e04 100644 --- a/llms/README.md +++ b/llms/README.md @@ -46,7 +46,7 @@ You can convert models in the Python API with: ```python from mlx_lm import convert -upload_repo = "mistralai/Mistral-7B-Instruct-v0.1" +upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit" convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo) ``` diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 0f2c4f03..5fbca3ae 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -128,7 +128,6 @@ class DecoderLayer(nn.Module): def __init__(self, config: ModelArgs): super().__init__() self.self_attn = Attention(config=config) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = MLP(config.hidden_size, config.intermediate_size) self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = LayerNorm( diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py new file mode 100644 index 00000000..aeebfc96 --- /dev/null +++ b/llms/mlx_lm/models/starcoder2.py @@ -0,0 +1,189 @@ +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .layers import LayerNorm + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int = None + max_position_embeddings: int = 16384 + norm_eps: float = None + rms_norm_eps: float = 1e-5 + norm_type: str = "layer_norm" + vocab_size: int = 49152 + rope_theta: float = 100000 + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.norm_eps is None: + self.norm_eps = self.rms_norm_eps + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.repeats = self.n_heads // self.n_kv_heads + + head_dim = args.hidden_size // args.num_attention_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) + self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.c_fc = nn.Linear(dim, hidden_dim, bias=True) + self.c_proj = nn.Linear(hidden_dim, dim, bias=True) + + def __call__(self, x): + return self.c_proj(nn.gelu(self.c_fc(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = LayerNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Starcoder2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model = Starcoder2Model(args) + # This is for 15B starcoder2 since it doesn't tie word embeddings + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + if not self.model.args.tie_word_embeddings: + return self.lm_head(out), cache + else: + out = out @ self.model.embed_tokens.weight.T + return out, cache + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 0fd688de..bfa5cdf9 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -32,6 +32,7 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int): "stablelm", "qwen2", "gemma", + "starcoder2", ]: check_lora_layers(len(model.model.layers))