MiniCPM implementation (#685)

* Added support for the MiniCPM architecture

* Updated utils.py and LORA.md

* Update implementation details for MiniCPM architecture

* Cleaning up

* fixed the missing lm_head layer problem

* Refactor Model class to dynamically handle tied and untied word embeddings

* Quick update

* added a dynamic RoPE scaling base calculation

* quick fix and clean up

* clean up again

* removed the MiniCPMNorm class as it's not used

* forgot something, sorry

* format

* version bump

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Gökdeniz Gülmez 2024-04-26 00:29:28 +02:00 committed by GitHub
parent 685012c2ad
commit 2c1c9e9024
4 changed files with 251 additions and 22 deletions

LORA.md

@@ -11,16 +11,17 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

 - Qwen2
 - Gemma
 - OLMo
+- MiniCPM

 ## Contents

-* [Run](#Run)
-* [Fine-tune](#Fine-tune)
-* [Evaluate](#Evaluate)
-* [Generate](#Generate)
-* [Fuse](#Fuse)
-* [Data](#Data)
-* [Memory Issues](#Memory-Issues)
+- [Run](#Run)
+- [Fine-tune](#Fine-tune)
+- [Evaluate](#Evaluate)
+- [Generate](#Generate)
+- [Fuse](#Fuse)
+- [Data](#Data)
+- [Memory Issues](#Memory-Issues)

 ## Run

@@ -122,7 +123,7 @@ To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
 to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
 useful for the sake of attribution and model versioning.

 For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

 ```shell
 mlx_lm.fuse \

@@ -144,38 +145,54 @@ can specify the file name with `--gguf-path`.

 ## Data

 The LoRA command expects you to provide a dataset with `--data`. The MLX
 Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.

 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
 loader expects a `test.jsonl` in the data directory.

 Currently, `*.jsonl` files support three data formats: `chat`,
 `completions`, and `text`. Here are three examples of these formats:

 `chat`:

 ```jsonl
-{"messages": [
-  {"role": "system", "content": "You are a helpful assistant." },
-  {"role": "user", "content": "Hello."},
-  {"role": "assistant", "content": "How can I assistant you today."},
-]}
+{
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "Hello."
+    },
+    {
+      "role": "assistant",
+      "content": "How can I assistant you today."
+    }
+  ]
+}
 ```

 `completions`:

 ```jsonl
-{"prompt": "What is the capital of France?", "completion": "Paris."}
+{
+  "prompt": "What is the capital of France?",
+  "completion": "Paris."
+}
 ```

 `text`:

 ```jsonl
-{"text": "This is an example for the model."}
+{
+  "text": "This is an example for the model."
+}
 ```

 Note, the format is automatically determined by the dataset. Note also, keys in

@@ -207,7 +224,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:

 1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
    with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
    more details.

 2. Try using a smaller batch size with `--batch-size`. The default is `4` so
    setting this to `2` or `1` will reduce memory consumption. This may slow

@@ -244,6 +261,5 @@ tokens-per-second, using the MLX Example
 [`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
 data set.

 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
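The `chat`, `completions`, and `text` records above live in the `*.jsonl` files the data loader reads. As a small sketch (hypothetical file paths and contents, writing one JSON object per line), a tiny `chat`-format dataset could be produced like this:

```python
# Hedged sketch: build a tiny chat-format dataset (hypothetical paths/content).
# Each line of a .jsonl file holds one JSON object shaped like the examples above.
import json
from pathlib import Path

examples = [
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello."},
            {"role": "assistant", "content": "How can I assist you today?"},
        ]
    },
]

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
with open(data_dir / "train.jsonl", "w") as f:
    for record in examples:
        f.write(json.dumps(record) + "\n")
```

As the README above describes, a `valid.jsonl` (and, for `--test`, a `test.jsonl`) goes in the same directory.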

minicpm.py (new file)

@@ -0,0 +1,212 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dim_model_base: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    max_position_embeddings: int
    scale_depth: float
    scale_emb: float
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[str, float]]] = None
    tie_word_embeddings: bool = False


class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_heads = n_heads = args.num_attention_heads
        self.rope_theta = args.rope_theta
        self.max_position_embeddings = args.max_position_embeddings

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.num_key_value_heads = args.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            dims=self.head_dim,
            traditional=args.rope_traditional,
            base=self.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ):
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        attn_output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(attn_output), (keys, values)


class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_hidden_layers = args.num_hidden_layers

        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        self.scale_depth = args.scale_depth
        self.num_hidden_layers = args.num_hidden_layers

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        # MiniCPM scales each residual branch by scale_depth / sqrt(num_hidden_layers)
        h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        return out, cache


class MiniCPMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        # Token embeddings are scaled by scale_emb
        h = self.embed_tokens(inputs) * self.args.scale_emb

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniCPMModel(args)

        if not self.args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)

        if not self.args.tie_word_embeddings:
            # Logits are computed from hidden states scaled by hidden_size / dim_model_base
            out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
        else:
            out = out @ self.model.embed_tokens.weight.T

        return out, cache

    def sanitize(self, weights):
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return weights

    @property
    def layers(self):
        return self.model.layers
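With the model class registered, a MiniCPM checkpoint can be loaded and sampled through the usual `mlx_lm` entry points. The snippet below is a minimal sketch, assuming the `mlx_lm.load`/`mlx_lm.generate` helpers; `openbmb/MiniCPM-2B-sft-bf16` is only an example repo id, so substitute whichever MiniCPM checkpoint (or locally converted model) you actually use.

```python
# Minimal sketch: load a MiniCPM checkpoint and generate a short completion.
# The repo id is an example; any compatible MiniCPM checkpoint should work.
from mlx_lm import load, generate

model, tokenizer = load("openbmb/MiniCPM-2B-sft-bf16")  # example repo id

response = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    max_tokens=100,
)
print(response)
```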

utils.py

@@ -77,6 +77,7 @@ def linear_to_lora_layers(
         "gemma",
         "starcoder2",
         "cohere",
+        "minicpm",
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
         if model.model_type == "mixtral":
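Adding `"minicpm"` to this list means MiniCPM decoder layers receive LoRA adapters on their `self_attn.q_proj` and `self_attn.v_proj` projections. Conceptually that amounts to wrapping those linear layers with a low-rank update; the sketch below is a generic illustration of the idea, not mlx_lm's own `LoRALinear`, with the rank, scale, and helper name chosen arbitrarily.

```python
# Generic sketch of what selecting q_proj/v_proj for LoRA means (illustrative only).
import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    """Wraps an existing nn.Linear with a trainable low-rank update."""

    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear  # base projection, left untouched
        self.scale = scale
        # A starts small and random, B starts at zero so training begins from the base weights
        self.lora_a = mx.random.uniform(
            low=-1.0 / in_dims**0.5, high=1.0 / in_dims**0.5, shape=(in_dims, rank)
        )
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)


def add_lora_to_minicpm(model, num_layers: int = 16):
    # Wrap the attention query/value projections of the last `num_layers` blocks,
    # mirroring the keys selected in linear_to_lora_layers.
    for layer in model.layers[-num_layers:]:
        layer.self_attn.q_proj = LoRALinear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear(layer.self_attn.v_proj)
```

The library's real adapter also handles freezing, quantized layers, and fusing the adapters back after training; this is only meant to show which projections the new entry targets.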

version.py

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.10.0"
+__version__ = "0.12.0"