diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 41557c29..851c995c 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. -- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`. \ No newline at end of file +- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`. \ No newline at end of file diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py new file mode 100644 index 00000000..6ca46a72 --- /dev/null +++ b/llms/mlx_lm/models/helium.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + rms_norm_eps: float + vocab_size: int + attention_bias: bool + head_dim: int + max_position_embeddings: int + mlp_bias: bool + model_type: str + rope_theta: float + tie_word_embeddings: bool + + +class HeliumAttention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class HeliumMLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.hidden_size = args.hidden_size + self.intermediate_size = args.intermediate_size + + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=args.mlp_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=args.mlp_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=args.mlp_bias + ) + + def __call__(self, x: mx.array) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class HeliumDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.hidden_size = args.hidden_size + + self.self_attn = HeliumAttention(args) + self.mlp = HeliumMLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class HeliumModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_hidden_layers = args.num_hidden_layers + self.vocab_size = args.vocab_size + + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + + self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)] + + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ) -> mx.array: + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + + self.model = HeliumModel(args) + + self.vocab_size = args.vocab_size + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ) -> mx.array: + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 594f8040..c0e52731 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -94,6 +94,7 @@ def linear_to_lora_layers( "phimoe", "gemma", "gemma2", + "helium", "starcoder2", "cohere", "cohere2",