Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Starcoder2: Update config and change GQA to use repeat (#520)
* update config
* change gqa to use repeat instead of concatenate
* contribution
This commit is contained in:
parent 1e3daea3bb
commit 3fdf85e79d
@@ -13,3 +13,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
 - Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
+- Prince Canuma: Helped add support for `Starcoder2` models.
@@ -18,20 +18,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int = None
     max_position_embeddings: int = 16384
-    norm_eps: float = None
-    rms_norm_eps: float = 1e-5
+    norm_epsilon: float = 1e-5
     norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
 
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.norm_eps is None:
-            self.norm_eps = self.rms_norm_eps
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
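The config hunk above replaces the split `norm_eps` / `rms_norm_eps` fields and their `__post_init__` fallback with a single `norm_epsilon` field whose default is set inline. A minimal sketch of the pattern, using a standalone dataclass with a hypothetical `from_dict` helper and illustrative values rather than the repository's `BaseModelArgs`:

```python
from dataclasses import dataclass, fields


@dataclass
class ModelArgsSketch:
    # Hypothetical stand-in for ModelArgs; only fields relevant to the change.
    hidden_size: int = 3072          # illustrative value
    num_attention_heads: int = 24    # illustrative value
    num_key_value_heads: int = 2     # illustrative value
    norm_epsilon: float = 1e-5       # single eps field, default set inline
    norm_type: str = "layer_norm"

    @classmethod
    def from_dict(cls, params: dict) -> "ModelArgsSketch":
        # Keep only keys that match dataclass fields, so unrelated keys in a
        # model's config.json are ignored instead of raising a TypeError.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in params.items() if k in names})


# A config that provides "norm_epsilon" directly needs no __post_init__
# fallback from a separate rms_norm_eps field.
args = ModelArgsSketch.from_dict({"hidden_size": 3072, "norm_epsilon": 1e-5})
print(args.norm_epsilon)  # 1e-05
```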
@@ -68,12 +60,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
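For context on the GQA change: both the old concatenate-based helper and `mx.repeat` expand the key/value heads so each group of query heads sees its shared KV head. A self-contained sketch with illustrative shapes and head counts (not taken from a specific Starcoder2 checkpoint):

```python
import mlx.core as mx

B, L, head_dim = 1, 4, 64
n_heads, n_kv_heads = 24, 2       # illustrative GQA head counts
repeats = n_heads // n_kv_heads   # query heads per shared KV head

# Keys after reshape/transpose have shape (B, n_kv_heads, L, head_dim).
keys = mx.random.normal((B, n_kv_heads, L, head_dim))

# Old approach: stack `repeats` copies on a new axis, then fold into heads.
old = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2)
old = old.reshape(B, n_heads, L, head_dim)

# New approach: one mx.repeat call along the heads axis.
new = mx.repeat(keys, repeats, axis=1)

print(old.shape, new.shape)   # both (1, 24, 4, 64)
print(mx.allclose(old, new))  # the two expansions produce the same array
```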
@@ -111,9 +100,9 @@ class TransformerBlock(nn.Module):
 
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
         self.post_attention_layernorm = LayerNorm(
-            args.hidden_size, eps=args.rms_norm_eps
+            args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
 
@@ -141,7 +130,7 @@ class Starcoder2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
 
     def __call__(
         self,