mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Starcoder2: Update config and change GQA to use repeat (#520)
* update config * change gqa to use repeat instead of concante * contribution
This commit is contained in:
@@ -18,20 +18,12 @@ class ModelArgs(BaseModelArgs):
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int = None
|
||||
max_position_embeddings: int = 16384
|
||||
norm_eps: float = None
|
||||
rms_norm_eps: float = 1e-5
|
||||
norm_epsilon: float = 1e-5
|
||||
norm_type: str = "layer_norm"
|
||||
vocab_size: int = 49152
|
||||
rope_theta: float = 100000
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.norm_eps is None:
|
||||
self.norm_eps = self.rms_norm_eps
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
@@ -68,12 +60,9 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
@@ -111,9 +100,9 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
args.hidden_size, eps=args.norm_epsilon
|
||||
)
|
||||
self.args = args
|
||||
|
||||
@@ -141,7 +130,7 @@ class Starcoder2Model(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user