mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add support for ibm granite (#758)
* add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * remove unused function * rebase fix * move position emebedding to mask creation * add to tuner and format * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * rebase fix * move position emebedding to mask creation * add to tuner and format * refactor mask * remove dropout layers
This commit is contained in:
parent
9fc6efbd90
commit
b044ce2acf
@ -4,6 +4,13 @@ from dataclasses import dataclass
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
|
195
llms/mlx_lm/models/gpt_bigcode.py
Normal file
195
llms/mlx_lm/models/gpt_bigcode.py
Normal file
@ -0,0 +1,195 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_inner: int
|
||||
n_head: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
multi_query: bool = True
|
||||
attention_bias: bool = True
|
||||
mlp_bias: bool = True
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = 1 if self.multi_query else self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim = args.n_embd
|
||||
self.n_heads = n_heads = args.n_head
|
||||
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
|
||||
|
||||
self.head_dim = head_dim = dim // n_heads
|
||||
|
||||
self.kv_dim = n_kv_heads * head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
|
||||
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.n_embd
|
||||
hidden_dim = args.n_inner
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
B, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
mask = create_additive_causal_mask(
|
||||
hidden_states.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(hidden_states.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = GPTBigCodeModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -17,9 +17,12 @@ class ModelArgs(BaseModelArgs):
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
@ -44,11 +47,15 @@ class Attention(nn.Module):
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
@ -93,11 +100,19 @@ class Attention(nn.Module):
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
@ -109,7 +124,7 @@ class TransformerBlock(nn.Module):
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
@ -129,13 +144,6 @@ class TransformerBlock(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -175,10 +183,11 @@ class LlamaModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -186,7 +195,11 @@ class Model(nn.Module):
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
|
@ -106,6 +106,9 @@ def linear_to_lora_layers(
|
||||
if model.model_type == "qwen2_moe":
|
||||
keys.add("mlp.gate")
|
||||
keys.add("mlp.shared_expert_gate")
|
||||
|
||||
elif model.model_type == "gpt_bigcode":
|
||||
keys = set(["attn.c_attn"])
|
||||
elif model.model_type == "olmo":
|
||||
keys = set(["att_proj"])
|
||||
elif model.model_type == "openelm":
|
||||
|
Loading…
Reference in New Issue
Block a user