From 2c1c9e902481537062278eb6ae3d45fa9aa4bd38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Fri, 26 Apr 2024 00:29:28 +0200
Subject: [PATCH] MiniCPM implementation (#685)

* Added support for the MiniCPM architecture

* Added support for the MiniCPM architecture

* Updated utils.py and LORA.md

* Updated utils.py and LORA.md

* Update implementation details for MiniCPM architecture

* Cleaning up

* fixed the missing lm.head layer problem

* Refactor Model class to dynamically handle tied and untied word embeddings

* Quick update

* added a dynamic rope scaling base calculation

* Added support for the MiniCPM architecture

* Added support for the MiniCPM architecture

* Updated utils.py and LORA.md

* Updated utils.py and LORA.md

* Update implementation details for MiniCPM architecture

* Cleaning up

* fixed the missing lm.head layer problem

* Refactor Model class to dynamically handle tied and untied word embeddings

* added a dynamic rope scaling base calculation

* quick fix and clean up

* clean up again

* removed the MiniCPMNorm class as it's not used

* forgot something, sorry

* format

* version bump

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/LORA.md           |  58 ++++++----
 llms/mlx_lm/models/minicpm.py | 212 ++++++++++++++++++++++++++++++++++
 llms/mlx_lm/tuner/utils.py    |   1 +
 llms/mlx_lm/version.py        |   2 +-
 4 files changed, 251 insertions(+), 22 deletions(-)
 create mode 100644 llms/mlx_lm/models/minicpm.py

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 6d9392d5..94206ad3 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -11,16 +11,17 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - Qwen2
 - Gemma
 - OLMo
+- MiniCPM
 
 ## Contents
 
-* [Run](#Run)
-  * [Fine-tune](#Fine-tune)
-  * [Evaluate](#Evaluate)
-  * [Generate](#Generate)
-* [Fuse](#Fuse)
-* [Data](#Data)
-* [Memory Issues](#Memory-Issues)
+- [Run](#Run)
+  - [Fine-tune](#Fine-tune)
+  - [Evaluate](#Evaluate)
+  - [Generate](#Generate)
+- [Fuse](#Fuse)
+- [Data](#Data)
+- [Memory Issues](#Memory-Issues)
 
 ## Run
 
@@ -122,7 +123,7 @@ To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
 useful for the sake of attribution and model versioning.
 
-For example, to fuse and upload a model derived from Mistral-7B-v0.1, run: 
+For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
 
 ```shell
 mlx_lm.fuse \
@@ -144,38 +145,54 @@ can specify the file name with `--gguf-path`.
 
 ## Data
 
-The LoRA command expects you to provide a dataset with `--data`. The MLX 
+The LoRA command expects you to provide a dataset with `--data`. The MLX
 Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.
 
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
-loader expects a `test.jsonl` in the data directory. 
+loader expects a `test.jsonl` in the data directory.
 
 Currently, `*.jsonl` files support three data formats: `chat`,
 `completions`, and `text`. Here are three examples of these formats:
 
 `chat`:
- 
+
 ```jsonl
-{"messages": [
-  {"role": "system", "content": "You are a helpful assistant."},
-  {"role": "user", "content": "Hello."},
-  {"role": "assistant", "content": "How can I assistant you today."},
-]}
+{
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "Hello."
+    },
+    {
+      "role": "assistant",
+      "content": "How can I assist you today?"
+    }
+  ]
+}
 ```
 
 `completions`:
- 
+
 ```jsonl
-{"prompt": "What is the capital of France?", "completion": "Paris."}
+{
+  "prompt": "What is the capital of France?",
+  "completion": "Paris."
+}
 ```
 
 `text`:
 
 ```jsonl
-{"text": "This is an example for the model."}
+{
+  "text": "This is an example for the model."
+}
 ```
 
 Note, the format is automatically determined by the dataset. Note also, keys in
@@ -207,7 +224,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
 
 1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
    with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
-   more details. 
+   more details.
 
 2. Try using a smaller batch size with `--batch-size`. The default is `4` so
    setting this to `2` or `1` will reduce memory consumption. This may slow
@@ -244,6 +261,5 @@ tokens-per-second, using the MLX Example
 [`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
 data set.
 
-
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
new file mode 100644
index 00000000..d3119f71
--- /dev/null
+++ b/llms/mlx_lm/models/minicpm.py
@@ -0,0 +1,212 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    dim_model_base: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: int
+    max_position_embeddings: int
+    scale_depth: float
+    scale_emb: float
+    rope_theta: float = 1000000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[str, float]]] = None
+    tie_word_embeddings: bool = False
+
+
+class MLP(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
+
+    def __call__(self, x):
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+
+        self.hidden_size = args.hidden_size
+        self.num_heads = n_heads = args.num_attention_heads
+        self.rope_theta = args.rope_theta
+        self.max_position_embeddings = args.max_position_embeddings
+
+        self.head_dim = head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+
+        self.num_key_value_heads = args.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        )
+
+        rope_scale = (
+            1 / args.rope_scaling["factor"]
+            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
+            else 1
+        )
+
+        self.rope = nn.RoPE(
+            dims=self.head_dim,
+            traditional=args.rope_traditional,
+            base=self.rope_theta,
+            scale=rope_scale,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ):
+        B, L, _ = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        attn_output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+
+        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.o_proj(attn_output), (keys, values)
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.hidden_size = args.hidden_size
+        self.num_hidden_layers = args.num_hidden_layers
+
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+
+        self.scale_depth = args.scale_depth
+        self.num_hidden_layers = args.num_hidden_layers
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
+        return out, cache
+
+
+class MiniCPMModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        assert self.vocab_size > 0
+
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        h = self.embed_tokens(inputs) * self.args.scale_emb
+
+        mask = None
+        if h.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for e, layer in enumerate(self.layers):
+            h, cache[e] = layer(h, mask, cache[e])
+
+        return self.norm(h), cache
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = MiniCPMModel(args)
+
+        if not self.args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out, cache = self.model(inputs, cache)
+
+        if not self.args.tie_word_embeddings:
+            out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
+        else:
+            out = out @ self.model.embed_tokens.weight.T
+
+        return out, cache
+
+    def sanitize(self, weights):
+        if "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 0e1cdcdd..6662a038 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -77,6 +77,7 @@ def linear_to_lora_layers(
         "gemma",
         "starcoder2",
         "cohere",
+        "minicpm",
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
         if model.model_type == "mixtral":
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index f907220b..3e8b1fe1 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.10.0"
+__version__ = "0.12.0"
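
The new `minicpm` model type plugs into the existing `mlx_lm` loading and generation path, so a quick smoke test can live outside the diff. Below is a minimal sketch (not part of the patch) using the public `mlx_lm.load` and `mlx_lm.generate` helpers; the checkpoint path is only a placeholder for whichever converted MiniCPM weights you have locally or on the Hugging Face Hub.

```python
# Minimal smoke-test sketch, assuming a MiniCPM checkpoint already converted
# with `mlx_lm.convert` (optionally with -q for a quantized/QLoRA base).
# "path/to/minicpm-model" is a placeholder, not a real repo name.
from mlx_lm import load, generate

model, tokenizer = load("path/to/minicpm-model")

prompt = "Write a one-sentence summary of low-rank adaptation."
text = generate(model, tokenizer, prompt=prompt, max_tokens=100, verbose=True)
```

The same placeholder path should also work with `mlx_lm.lora --model <path> --train --data <data-dir>`, which exercises the new `minicpm` entry added to `linear_to_lora_layers` above.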