Starcoder2: Update config and change GQA to use repeat (#520)

* update config

* change GQA to use repeat instead of concatenate

* contribution
Prince Canuma 2024-03-03 15:12:03 +01:00 committed by GitHub
parent 1e3daea3bb
commit 3fdf85e79d
2 changed files with 7 additions and 17 deletions


@@ -13,3 +13,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
 - Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
+- Prince Canuma: Helped add support for `Starcoder2` models.


@@ -18,20 +18,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int = None
     max_position_embeddings: int = 16384
-    norm_eps: float = None
-    rms_norm_eps: float = 1e-5
+    norm_epsilon: float = 1e-5
     norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
 
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-        if self.norm_eps is None:
-            self.norm_eps = self.rms_norm_eps
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
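For reference, a minimal sketch of how the simplified configuration reads after this change (a hypothetical stand-in dataclass, not the actual mlx_lm ModelArgs; the field defaults come from the diff above, the constructor values are made up): there is now a single norm_epsilon field, with no __post_init__ reconciling norm_eps and rms_norm_eps.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ConfigSketch:
    # Hypothetical stand-in for ModelArgs, limited to the fields touched by this change.
    num_attention_heads: int
    num_key_value_heads: Optional[int] = None
    norm_epsilon: float = 1e-5   # single epsilon field, read straight from the checkpoint config
    norm_type: str = "layer_norm"

cfg = ConfigSketch(num_attention_heads=24, num_key_value_heads=2)
print(cfg.norm_epsilon)  # 1e-5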
@@ -68,12 +60,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
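A standalone sketch (variable names and shapes are illustrative, not part of the diff) showing why the two GQA expansions are interchangeable: repeating each key/value head with mx.repeat along the head axis yields the same tensor as the old expand-and-concatenate-then-reshape path.

import mlx.core as mx

B, n_kv_heads, L, head_dim, repeats = 1, 2, 5, 4, 3
n_heads = n_kv_heads * repeats

keys = mx.random.normal((B, n_kv_heads, L, head_dim))

# Old path: insert a new axis, tile it `repeats` times, then fold it into the head axis.
old = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2)
old = old.reshape([B, n_heads, L, head_dim])

# New path: repeat each KV head `repeats` times in place along the head axis.
new = mx.repeat(keys, repeats, axis=1)

print(old.shape == new.shape)        # True
print(mx.allclose(old, new).item())  # True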
@@ -111,9 +100,9 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
         self.post_attention_layernorm = LayerNorm(
-            args.hidden_size, eps=args.rms_norm_eps
+            args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
@@ -141,7 +130,7 @@ class Starcoder2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
 
     def __call__(
         self,