diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 29fdcdcb..0c723c86 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -9,7 +9,8 @@ from mlx.utils import tree_flatten
 
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import generate, load, LORA_SUPPORTED_MODELS
+from .utils import LORA_SUPPORTED_MODELS, generate, load
+
 
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
@@ -203,7 +204,7 @@ if __name__ == "__main__":
         steps_per_eval=args.steps_per_eval,
         steps_per_save=args.save_every,
         adapter_file=args.adapter_file,
-        max_seq_length=args.max_seq_length
+        max_seq_length=args.max_seq_length,
     )
     if args.train:
         print("Training")
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
new file mode 100644
index 00000000..8a511a1e
--- /dev/null
+++ b/llms/mlx_lm/models/olmo.py
@@ -0,0 +1,159 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs
+
+try:
+    import hf_olmo
+except ImportError:
+    print("To run olmo install ai2-olmo: pip install ai2-olmo")
+    exit(1)
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    d_model: int
+    n_layers: int
+    mlp_hidden_size: int
+    n_heads: int
+    vocab_size: int
+    embedding_size: int
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+    model_type: str = None
+
+
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        dim = args.d_model
+
+        self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
+        self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
+
+        self.att_norm = LayerNorm(dim, affine=False)
+        self.ff_norm = LayerNorm(dim, affine=False)
+
+        head_dim = dim // self.n_heads
+        self.scale = head_dim**-0.5
+
+        self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
+        self.attn_out = nn.Linear(dim, dim, bias=False)
+
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=args.rope_traditional,
+            base=args.rope_theta,
+        )
+
+        self.args = args
+
+    def attend(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores += mask
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.attn_out(output), (keys, values)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, cache = self.attend(self.att_norm(x), mask, cache)
+        h = x + r
+
+        x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+        out = h + self.ff_out(nn.silu(x2) * x1)
+        return out, cache
+
+
+class Transformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_layers = args.n_layers
+        self.wte = nn.Embedding(args.embedding_size, args.d_model)
+        self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
+        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        self.norm = LayerNorm(args.d_model, affine=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        h = self.wte(inputs)
+
+        mask = None
+        if h.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.blocks)
+
+        for e, block in enumerate(self.blocks):
+            h, cache[e] = block(h, mask, cache[e])
+
+        return self.ff_out(self.norm(h)), cache
+
+
+class OlmoModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.transformer = Transformer(args)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        return self.transformer(inputs, cache)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.model = OlmoModel(args)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        return self.model(inputs, cache)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0bac0489..61b8d9c0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -12,7 +12,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
-from .models import llama, mixtral, phi2, plamo, qwen, stablelm_epoch, qwen2
+from .models import llama, mixtral, olmo, phi2, plamo, qwen, qwen2, stablelm_epoch
 from .tuner.utils import apply_lora_layers
 
 # Constants
@@ -24,10 +24,15 @@ MODEL_MAPPING = {
     "stablelm_epoch": stablelm_epoch,
     "qwen": qwen,
     "plamo": plamo,
-    "qwen2": qwen2
+    "olmo": olmo,
+    "qwen2": qwen2,
 }
 LORA_SUPPORTED_MODELS = [
-    llama.Model, mixtral.Model, phi2.Model, stablelm_epoch.Model, qwen2.Model
+    llama.Model,
+    mixtral.Model,
+    phi2.Model,
+    stablelm_epoch.Model,
+    qwen2.Model,
 ]
 MAX_FILE_SIZE_GB = 5
 
diff --git a/llms/setup.py b/llms/setup.py
index 0772501a..00fd0e69 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
     requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
 setup(
     name="mlx-lm",
-    version="0.0.6",
+    version="0.0.8",
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",