mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add support for ibm granite (#758)
* add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * remove unused function * rebase fix * move position emebedding to mask creation * add to tuner and format * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * add support for granite 3-8B config * add gpt_bigcode * add positional embedding condition. * rebase fix * move position emebedding to mask creation * add to tuner and format * refactor mask * remove dropout layers
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,9 +17,12 @@ class ModelArgs(BaseModelArgs):
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
@@ -44,11 +47,15 @@ class Attention(nn.Module):
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
@@ -93,11 +100,19 @@ class Attention(nn.Module):
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
@@ -109,7 +124,7 @@ class TransformerBlock(nn.Module):
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
@@ -129,13 +144,6 @@ class TransformerBlock(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@@ -175,10 +183,11 @@ class LlamaModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -186,7 +195,11 @@ class Model(nn.Module):
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
|
||||
Reference in New Issue
Block a user