From b044ce2acf91b10f50e0b71ec206d9aa088f5a10 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 22 May 2024 05:16:31 +0200 Subject: [PATCH] Add support for ibm granite (#758) * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * remove unused function * rebase fix * move position emebedding to mask creation * add to tuner and format * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * rebase fix * move position emebedding to mask creation * add to tuner and format * refactor mask * remove dropout layers --- llms/mlx_lm/models/base.py | 7 ++ llms/mlx_lm/models/gpt_bigcode.py | 195 ++++++++++++++++++++++++++++++ llms/mlx_lm/models/llama.py | 53 +++++--- llms/mlx_lm/tuner/utils.py | 3 + 4 files changed, 238 insertions(+), 20 deletions(-) create mode 100644 llms/mlx_lm/models/gpt_bigcode.py diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 1e184294..b98a1909 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -4,6 +4,13 @@ from dataclasses import dataclass import mlx.core as mx +def create_additive_causal_mask(N: int, offset: int = 0): + rinds = mx.arange(offset + N) + linds = mx.arange(offset, offset + N) if offset else rinds + mask = linds[:, None] < rinds[None] + return mask * -1e9 + + class KVCache: def __init__(self, head_dim, n_kv_heads): diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py new file mode 100644 index 00000000..20af3d0b --- /dev/null +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -0,0 +1,195 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .base import BaseModelArgs, create_additive_causal_mask + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + n_embd: int + n_layer: int + n_inner: int + n_head: int + n_positions: int + layer_norm_epsilon: float + vocab_size: int + num_key_value_heads: int = None + multi_query: bool = True + attention_bias: bool = True + mlp_bias: bool = True + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = 1 if self.multi_query else self.n_head + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.dim = dim = args.n_embd + self.n_heads = n_heads = args.n_head + self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head + + self.head_dim = head_dim = dim // n_heads + + self.kv_dim = n_kv_heads * head_dim + + self.scale = head_dim**-0.5 + + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias) + self.c_proj = nn.Linear(dim, dim, bias=attention_bias) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + qkv = self.c_attn(x) + queries, keys, values = mx.split( + qkv, [self.dim, self.dim + self.kv_dim], axis=-1 + ) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.c_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.n_embd + hidden_dim = args.n_inner + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.c_proj(nn.gelu(self.c_fc(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_head = args.n_head + self.n_embd = args.n_embd + self.attn = Attention(args) + self.mlp = MLP(args) + self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) + self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r = self.attn(self.ln_1(x), mask, cache) + h = x + r + r = self.mlp(self.ln_2(h)) + out = h + r + return out + + +class GPTBigCodeModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + assert self.vocab_size > 0 + self.wte = nn.Embedding(args.vocab_size, args.n_embd) + self.wpe = nn.Embedding(args.n_positions, args.n_embd) + self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)] + self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + B, L = inputs.shape + + hidden_states = self.wte(inputs) + + mask = None + if hidden_states.shape[1] > 1: + + position_ids = mx.array(np.arange(L)) + hidden_states += self.wpe(position_ids) + + mask = create_additive_causal_mask( + hidden_states.shape[1], cache[0].offset if cache is not None else 0 + ) + mask = mask.astype(hidden_states.dtype) + + if cache is None: + cache = [None] * len(self.h) + + for layer, c in zip(self.h, cache): + hidden_states = layer(hidden_states, mask, cache=c) + + return self.ln_f(hidden_states) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.transformer = GPTBigCodeModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.transformer(inputs, cache) + if self.args.tie_word_embeddings: + out = self.transformer.wte.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.transformer.h + + @property + def head_dim(self): + return self.args.n_embd // self.args.n_head + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index b49d0419..55a2b5db 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, create_additive_causal_mask @dataclass @@ -17,9 +17,12 @@ class ModelArgs(BaseModelArgs): rms_norm_eps: float vocab_size: int num_key_value_heads: int = None + attention_bias: bool = False + mlp_bias: bool = False rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True def __post_init__(self): if self.num_key_value_heads is None: @@ -44,11 +47,15 @@ class Attention(nn.Module): head_dim = args.hidden_size // n_heads self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) rope_scale = ( 1 / args.rope_scaling["factor"] @@ -93,11 +100,19 @@ class Attention(nn.Module): class MLP(nn.Module): - def __init__(self, dim, hidden_dim): + def __init__(self, args: ModelArgs): super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) def __call__(self, x) -> mx.array: return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) @@ -109,7 +124,7 @@ class TransformerBlock(nn.Module): self.num_attention_heads = args.num_attention_heads self.hidden_size = args.hidden_size self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.mlp = MLP(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm( args.hidden_size, eps=args.rms_norm_eps @@ -129,13 +144,6 @@ class TransformerBlock(nn.Module): return out -def create_additive_causal_mask(N: int, offset: int = 0): - rinds = mx.arange(offset + N) - linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] - return mask * -1e9 - - class LlamaModel(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -175,10 +183,11 @@ class LlamaModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.args = args self.model_type = args.model_type self.model = LlamaModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - self.args = args + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, @@ -186,7 +195,11 @@ class Model(nn.Module): cache=None, ): out = self.model(inputs, cache) - return self.lm_head(out) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out def sanitize(self, weights): # Remove unused precomputed rotary freqs diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index cc085d78..83db9f51 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -106,6 +106,9 @@ def linear_to_lora_layers( if model.model_type == "qwen2_moe": keys.add("mlp.gate") keys.add("mlp.shared_expert_gate") + + elif model.model_type == "gpt_bigcode": + keys = set(["attn.c_attn"]) elif model.model_type == "olmo": keys = set(["att_proj"]) elif model.model_type == "openelm":