From 1415595409874971b5ad96d55980e7d4aa8a8043 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Sat, 20 Jan 2024 06:07:45 -0800
Subject: [PATCH] chore(lora): support mixtral in lora example (#343)
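
Add a Mixtral (sparse mixture-of-experts) model definition under
lora/models/ and register it in utils.py. LoRA adapters are applied to
the attention q_proj/v_proj projections and, when the model has one, to
the MoE router (block_sparse_moe.gate). Quantization in convert.py,
models.py and utils.py now passes a linear_class_predicate so the
8-output router layers are left unquantized.

An illustrative invocation (flag names follow lora.py's existing
arguments; the model path is a placeholder):

    python lora.py --model <mixtral-path-or-hf-repo> --train --lora-layers 8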
---
lora/convert.py | 8 +-
lora/fuse.py | 2 +
lora/lora.py | 4 +-
lora/models.py | 7 +-
lora/models/mixtral.py | 253 +++++++++++++++++++++++++++++++++++++++++
lora/utils.py | 9 +-
6 files changed, 279 insertions(+), 4 deletions(-)
create mode 100644 lora/models/mixtral.py
diff --git a/lora/convert.py b/lora/convert.py
index 9b2f6de6..bc85eb5e 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -20,7 +20,13 @@ def quantize(weights, config, args):
model.load_weights(list(weights.items()))
# Quantize the model:
- nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
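+ # Do not quantize the MoE routers (gate layers); the predicate picks them
+ # out by their num_local_experts (8) output rows.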
+ nn.QuantizedLinear.quantize_module(
+ model,
+ args.q_group_size,
+ args.q_bits,
+ linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+ and m.weight.shape[0] != 8,
+ )
# Update the config:
quantized_config["quantization"] = {
diff --git a/lora/fuse.py b/lora/fuse.py
index bde543b4..2ea265fb 100644
--- a/lora/fuse.py
+++ b/lora/fuse.py
@@ -56,6 +56,8 @@ if __name__ == "__main__":
for l in model.model.layers[-lora_layers:]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
+ if hasattr(l, "block_sparse_moe"):
+ l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
model.update(tree_unflatten(adapters))
fused_linears = [
diff --git a/lora/lora.py b/lora/lora.py
index b522dfdb..9efe8893 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -315,6 +315,8 @@ if __name__ == "__main__":
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
+ if hasattr(l, "block_sparse_moe"):
+ l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
@@ -349,7 +351,7 @@ if __name__ == "__main__":
if args.test:
print("Testing")
-
+ model.eval()
test_loss = evaluate(
model,
test_set,
diff --git a/lora/models.py b/lora/models.py
index 244d8f5a..293b4f96 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -328,7 +328,12 @@ def load(path_or_hf_repo: str):
model = Model(model_args)
if quantization is not None:
- nn.QuantizedLinear.quantize_module(model, **quantization)
+ nn.QuantizedLinear.quantize_module(
+ model,
+ **quantization,
+ linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+ and m.weight.shape[0] != 8,
+ )
model.load_weights(list(weights.items()))
diff --git a/lora/models/mixtral.py b/lora/models/mixtral.py
new file mode 100644
index 00000000..e70e0d2f
--- /dev/null
+++ b/lora/models/mixtral.py
@@ -0,0 +1,253 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ vocab_size: int = 32000
+ max_position_embeddings: int = 4096 * 32
+ hidden_size: int = 4096
+ intermediate_size: int = 14336
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 32
+ num_experts_per_tok: int = 2
+ num_key_value_heads: int = 8
+ num_local_experts: int = 8
+ rms_norm_eps: float = 1e-5
+ rope_theta: float = 1e6
+ rope_traditional: bool = False
+ model_type: Optional[str] = None
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+ def __post_init__(self):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dims: int, eps: float = 1e-5):
+ super().__init__()
+ self.weight = mx.ones((dims,))
+ self.eps = eps
+
+ def _norm(self, x):
+ return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
+
+ def __call__(self, x):
+ output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+ return self.weight * output
+
+
+class MixtralAttention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.hidden_size = args.hidden_size
+ self.num_heads = args.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = args.num_key_value_heads
+ self.max_position_embeddings = args.max_position_embeddings
+ self.rope_theta = args.rope_theta
+
+ self.repeats = self.num_heads // self.num_key_value_heads
+
+ self.scale = self.head_dim**-0.5
+
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
+ )
+
+ self.rope = nn.RoPE(
+ self.head_dim,
+ traditional=args.rope_traditional,
+ base=args.rope_theta,
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+ ) -> mx.array:
+ B, L, D = x.shape
+
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+ # Prepare the queries, keys and values for the attention computation
+ queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+ 0, 2, 1, 3
+ )
+
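+ # Grouped-query attention: tile each key/value head so the key/value head
+ # count matches the number of query heads.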
+ def repeat(a):
+ a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
+ return a.reshape([B, self.num_heads, L, -1])
+
+ if self.repeats > 1:
+ keys, values = map(repeat, (keys, values))
+
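+ # With a cache, rotate the new positions with an offset equal to the cache
+ # length and append the new keys/values to the cached ones.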
+ if cache is not None:
+ key_cache, value_cache = cache
+ queries = self.rope(queries, offset=key_cache.shape[2])
+ keys = self.rope(keys, offset=key_cache.shape[2])
+ keys = mx.concatenate([key_cache, keys], axis=2)
+ values = mx.concatenate([value_cache, values], axis=2)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
+ if mask is not None:
+ scores += mask
+ scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+ output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output), (keys, values)
+
+
+class MixtralBLockSparseTop2MLP(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.ffn_dim = args.intermediate_size
+ self.hidden_dim = args.hidden_size
+
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+ self.act_fn = nn.silu
+
+ def __call__(self, x: mx.array) -> mx.array:
+ current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
+ current_hidden_states = self.w2(current_hidden_states)
+ return current_hidden_states
+
+
+class MixtralSparseMoeBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.hidden_dim = args.hidden_size
+ self.ffn_dim = args.intermediate_size
+ self.num_experts = args.num_local_experts
+ self.num_experts_per_tok = args.num_experts_per_tok
+
+ # gating
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+ self.experts = [
+ MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
+ ]
+
+ def __call__(self, x: mx.array) -> mx.array:
+ ne = self.num_experts_per_tok
+ orig_shape = x.shape
+ x = x.reshape(-1, x.shape[-1])
+
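+ # Route each token: take the gate logits, keep the indices of the top
+ # `num_experts_per_tok` experts, and softmax the selected logits into
+ # mixing weights.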
+ gates = self.gate(x)
+ inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
+
+ scores = mx.softmax(
+ mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
+ axis=-1,
+ ).astype(gates.dtype)
+
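+ # Materialize the routing indices on the host so numpy can group tokens by
+ # expert; each expert then runs once over the tokens assigned to it.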
+ mx.eval(inds)
+ inds = np.array(inds)
+ y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+ for e, expert in enumerate(self.experts):
+ idx1, idx2 = map(mx.array, np.where(inds == e))
+ if idx1.size == 0:
+ continue
+ y[idx1, idx2] = expert(x[idx1])
+
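+ # Mix the selected experts' outputs with the gate scores.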
+ y = (y * scores[:, :, None]).sum(axis=1)
+
+ return y.reshape(orig_shape)
+
+
+class MixtralDecoderLayer(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.hidden_size = args.hidden_size
+
+ self.self_attn = MixtralAttention(args)
+
+ self.block_sparse_moe = MixtralSparseMoeBlock(args)
+ self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+ ) -> mx.array:
+ r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+ h = x + r
+ r = self.block_sparse_moe(self.post_attention_layernorm(h))
+ out = h + r
+ return out, cache
+
+
+class MixtralModel(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.vocab_size = args.vocab_size
+ self.num_hidden_layers = args.num_hidden_layers
+
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [
+ MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+ ]
+ self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ):
+ h = self.embed_tokens(inputs)
+
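+ # Build a causal mask only when processing more than one token (prompt
+ # evaluation); single-token decoding steps need no mask.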
+ mask = None
+ T = h.shape[1]
+ if T > 1:
+ mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+ mask = mask.astype(h.dtype)
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for e, layer in enumerate(self.layers):
+ h, cache[e] = layer(h, mask, cache[e])
+
+ return self.norm(h), cache
+
+
+class Model(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.model = MixtralModel(args)
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ):
+ out, cache = self.model(inputs, cache)
+ return self.lm_head(out), cache
diff --git a/lora/utils.py b/lora/utils.py
index b691227d..80d59399 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -9,6 +9,7 @@ from typing import Generator
import mlx.core as mx
import mlx.nn as nn
import models.llama as llama
+import models.mixtral as mixtral
import models.phi2 as phi2
import transformers
from huggingface_hub import snapshot_download
@@ -18,6 +19,7 @@ MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
+ "mixtral": mixtral,
}
@@ -150,7 +152,12 @@ def load(path_or_hf_repo: str):
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if quantization is not None:
- nn.QuantizedLinear.quantize_module(model, **quantization)
+ nn.QuantizedLinear.quantize_module(
+ model,
+ **quantization,
+ linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+ and m.weight.shape[0] != 8,
+ )
model.load_weights(list(weights.items()))