Add support for ibm granite (#758)

* add support for granite 3-8B config

* add gpt_bigcode

* add positional embedding condition.

* remove unused function

* rebase fix

* move position embedding to mask creation

* add to tuner and format

* refactor mask

* remove dropout layers

Prince Canuma 2024-05-22 05:16:31 +02:00 committed by GitHub
parent 9fc6efbd90
commit b044ce2acf
4 changed files with 238 additions and 20 deletions
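
The change adds a gpt_bigcode model implementation and extends the llama model with the attention/MLP bias and tied-embedding options that IBM Granite configurations rely on. Once converted, such a checkpoint should load through the usual mlx-lm entry points. A minimal sketch, assuming the standard mlx_lm load/generate helpers; the repo id is only illustrative:

from mlx_lm import load, generate

# Illustrative checkpoint; any Granite model covered by this change should work the same way.
model, tokenizer = load("ibm-granite/granite-3b-code-base")
text = generate(model, tokenizer, prompt="def fibonacci(n):", max_tokens=64)
print(text)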

View File

@@ -4,6 +4,13 @@ from dataclasses import dataclass

import mlx.core as mx


+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9


class KVCache:
    def __init__(self, head_dim, n_kv_heads):
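
The helper added above (shared from base so several model files can import it) builds an additive causal mask: keys that lie in a query's future are set to a large negative value and everything else to 0, so the result can simply be added to the attention scores. A quick check, assuming the file shown here is mlx_lm/models/base.py:

from mlx_lm.models.base import create_additive_causal_mask

# Two new tokens appended after three cached ones -> a 2 x 5 mask.
mask = create_additive_causal_mask(2, offset=3)
print(mask.shape)  # (2, 5)
print(mask)        # 0 where attention is allowed, -1e9 on future positions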

View File

@@ -0,0 +1,195 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_embd: int
    n_layer: int
    n_inner: int
    n_head: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None
    multi_query: bool = True
    attention_bias: bool = True
    mlp_bias: bool = True
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = 1 if self.multi_query else self.n_head
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.dim = dim = args.n_embd
        self.n_heads = n_heads = args.n_head
        self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
        self.head_dim = head_dim = dim // n_heads
        self.kv_dim = n_kv_heads * head_dim

        self.scale = head_dim**-0.5

        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
        self.c_proj = nn.Linear(dim, dim, bias=attention_bias)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(
            qkv, [self.dim, self.dim + self.kv_dim], axis=-1
        )

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)
class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.n_embd
        hidden_dim = args.n_inner

        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out
class GPTBigCodeModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0
        self.wte = nn.Embedding(args.vocab_size, args.n_embd)
        self.wpe = nn.Embedding(args.n_positions, args.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
        self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        B, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:
            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            mask = create_additive_causal_mask(
                hidden_states.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(hidden_states.dtype)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = GPTBigCodeModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.transformer.wte.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.transformer.h

    @property
    def head_dim(self):
        return self.args.n_embd // self.args.n_head

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
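
For a quick standalone check of the module above, a tiny throwaway configuration is enough. The values below are deliberately small and purely illustrative, not real Granite hyperparameters, and the import path assumes the file lands at mlx_lm/models/gpt_bigcode.py as its relative import suggests:

import mlx.core as mx

from mlx_lm.models.gpt_bigcode import Model, ModelArgs

args = ModelArgs(
    model_type="gpt_bigcode",
    n_embd=64,            # toy width, not a real Granite value
    n_layer=2,
    n_inner=256,
    n_head=4,
    n_positions=128,
    layer_norm_epsilon=1e-5,
    vocab_size=100,
)
model = Model(args)

tokens = mx.array([[1, 2, 3, 4]])  # (batch, sequence)
logits = model(tokens)             # multi-query attention, tied embeddings by default
print(logits.shape)                # (1, 4, 100)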

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
@@ -17,9 +17,12 @@ class ModelArgs(BaseModelArgs):
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
+    attention_bias: bool = False
+    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
@@ -44,11 +47,15 @@ class Attention(nn.Module):
        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

+        if hasattr(args, "attention_bias"):
+            attention_bias = args.attention_bias
+        else:
+            attention_bias = False
+
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = (
            1 / args.rope_scaling["factor"]
@@ -93,11 +100,19 @@ class Attention(nn.Module):

class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
+    def __init__(self, args: ModelArgs):
        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        if hasattr(args, "mlp_bias"):
+            mlp_bias = args.mlp_bias
+        else:
+            mlp_bias = False
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
@@ -109,7 +124,7 @@ class TransformerBlock(nn.Module):
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
-        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
@@ -129,13 +144,6 @@ class TransformerBlock(nn.Module):
        return out


-def create_additive_causal_mask(N: int, offset: int = 0):
-    rinds = mx.arange(offset + N)
-    linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
-    return mask * -1e9
-
-
class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
@@ -175,10 +183,11 @@ class LlamaModel(nn.Module):

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
+        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-        self.args = args
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
@@ -186,7 +195,11 @@ class Model(nn.Module):
        cache=None,
    ):
        out = self.model(inputs, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
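
With tie_word_embeddings enabled, the model reuses the token embedding matrix as the output projection through mlx's Embedding.as_linear, so a tied checkpoint carries no separate lm_head weight. The same mechanism in miniature:

import mlx.core as mx
import mlx.nn as nn

# One (vocab, dim) table embeds tokens on the way in and, transposed,
# maps hidden states back to vocabulary logits on the way out.
emb = nn.Embedding(num_embeddings=10, dims=4)
hidden = mx.random.normal((1, 3, 4))  # (batch, seq, dim)
logits = emb.as_linear(hidden)        # equivalent to hidden @ emb.weight.T
print(logits.shape)                   # (1, 3, 10)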

View File

@@ -106,6 +106,9 @@ def linear_to_lora_layers(
        if model.model_type == "qwen2_moe":
            keys.add("mlp.gate")
            keys.add("mlp.shared_expert_gate")
+    elif model.model_type == "gpt_bigcode":
+        keys = set(["attn.c_attn"])
    elif model.model_type == "olmo":
        keys = set(["att_proj"])
    elif model.model_type == "openelm":