chore(lora): support mixtral in lora example (#343)

This commit is contained in:
Anchen 2024-01-20 06:07:45 -08:00 committed by GitHub
parent 527cea4027
commit 1415595409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 279 additions and 4 deletions

View File

@ -20,7 +20,13 @@ def quantize(weights, config, args):
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()))
# Quantize the model: # Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) nn.QuantizedLinear.quantize_module(
model,
args.q_group_size,
args.q_bits,
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)
# Update the config: # Update the config:
quantized_config["quantization"] = { quantized_config["quantization"] = {

View File

@ -56,6 +56,8 @@ if __name__ == "__main__":
for l in model.model.layers[-lora_layers:]: for l in model.model.layers[-lora_layers:]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
if hasattr(l, "block_sparse_moe"):
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
model.update(tree_unflatten(adapters)) model.update(tree_unflatten(adapters))
fused_linears = [ fused_linears = [

View File

@ -315,6 +315,8 @@ if __name__ == "__main__":
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
if hasattr(l, "block_sparse_moe"):
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M") print(f"Total parameters {p:.3f}M")
@ -349,7 +351,7 @@ if __name__ == "__main__":
if args.test: if args.test:
print("Testing") print("Testing")
model.eval()
test_loss = evaluate( test_loss = evaluate(
model, model,
test_set, test_set,

View File

@ -328,7 +328,12 @@ def load(path_or_hf_repo: str):
model = Model(model_args) model = Model(model_args)
if quantization is not None: if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization) nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()))

253
lora/models/mixtral.py Normal file
View File

@ -0,0 +1,253 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values used to build the Mixtral model in this file.

    Every field has a default, so an instance can be created from a partial
    config. ``rope_scaling`` and ``model_type`` are stored but are not read by
    the model classes defined in this file.
    """

    vocab_size: int = 32000
    max_position_embeddings: int = 4096 * 32
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_experts_per_tok: int = 2
    # May be passed as None, in which case __post_init__ falls back to
    # num_attention_heads.
    num_key_value_heads: Optional[int] = 8
    num_local_experts: int = 8
    rms_norm_eps: float = 1e-5
    # NOTE: the original declared ``vocab_size`` a second time here without a
    # default; the redundant declaration was dropped. Field order and the
    # default value (32000) are unchanged.
    rope_theta: float = 1e6
    rope_traditional: bool = False
    model_type: Optional[str] = None
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        """Default the number of key/value heads to the number of query heads."""
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dims: int, eps: float = 1e-5):
        """Create a scale vector of ``dims`` ones and remember ``eps``."""
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square along the last axis."""
        mean_square = x.square().mean(-1, keepdims=True)
        inv_rms = mx.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def __call__(self, x):
        """Normalise in float32, cast back to the input dtype, then scale."""
        in_dtype = x.dtype
        normalised = self._norm(x.astype(mx.float32))
        return self.weight * normalised.astype(in_dtype)
class MixtralAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and shared K/V heads.

    Keys and values are projected to ``num_key_value_heads`` heads and then
    repeated ``num_heads // num_key_value_heads`` times so that they line up
    with the query heads.
    """

    def __init__(self, args: ModelArgs):
        """Build the four bias-free projections and the rotary embedding."""
        super().__init__()
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.max_position_embeddings = args.max_position_embeddings
        self.rope_theta = args.rope_theta

        # How many query heads share each key/value head.
        self.repeats = self.num_heads // self.num_key_value_heads

        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Apply attention to ``x``.

        Args:
            x: Rank-3 input, unpacked as (batch, sequence, features).
            mask: Optional additive mask that is added to the attention scores.
            cache: Optional ``(keys, values)`` pair from earlier steps. Its
                sequence axis (axis 2) gives the rotary offset, and the new
                keys/values are appended along that axis.

        Returns:
            ``(output, (keys, values))``: the projected attention output and
            the keys/values (already repeated and rotated) to pass back in as
            ``cache`` on the next call.
        """
        # Only the batch and sequence sizes are needed; the unpack also
        # enforces that the input is rank 3.
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation:
        # split the feature axis into heads and move heads ahead of sequence.
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        def repeat(a):
            # Duplicate each key/value head `self.repeats` times so the head
            # axis matches the number of query heads.
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.num_heads, L, -1])

        if self.repeats > 1:
            keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            # Positions continue from where the cached sequence ended.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax is computed in float32 and cast back to the scores' dtype.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
class MixtralBLockSparseTop2MLP(nn.Module):
    """One expert feed-forward network: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: ModelArgs):
        """Create the three bias-free projections of a single expert."""
        super().__init__()
        self.ffn_dim = args.intermediate_size
        self.hidden_dim = args.hidden_size

        hidden, ffn = self.hidden_dim, self.ffn_dim
        self.w1 = nn.Linear(hidden, ffn, bias=False)
        self.w2 = nn.Linear(ffn, hidden, bias=False)
        self.w3 = nn.Linear(hidden, ffn, bias=False)

        self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        """Run the gated feed-forward transform on ``x``."""
        activated = self.act_fn(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
class MixtralSparseMoeBlock(nn.Module):
    """Mixture-of-experts block that routes each row to its top-scoring experts.

    A linear gate scores every expert per input row. The
    ``num_experts_per_tok`` best experts are evaluated for that row and their
    outputs are combined, weighted by a softmax over the selected gate scores.
    """

    def __init__(self, args: ModelArgs):
        """Create the gate and ``num_local_experts`` expert networks."""
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = [
            MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the routed experts to ``x`` and return an array of the same shape.

        The input is flattened to 2-D (rows x features) for routing and
        reshaped back to its original shape before returning.
        """
        ne = self.num_experts_per_tok
        orig_shape = x.shape
        x = x.reshape(-1, x.shape[-1])

        gates = self.gate(x)
        # Partition on position ne - 1: the first `ne` entries are then the
        # `ne` largest gate scores. Using `kth=ne` selects the same set but is
        # out of range when ne equals the number of experts. The order inside
        # the selected set does not matter because scores and outputs are
        # indexed with the same `inds`.
        inds = mx.stop_gradient(
            mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
        )
        # Softmax over the selected scores only, computed in float32.
        scores = mx.softmax(
            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
            axis=-1,
        ).astype(gates.dtype)

        # Evaluate the indices so they can be copied to NumPy, where the
        # per-expert row lookup below is done.
        mx.eval(inds)
        inds = np.array(inds)

        # Allocate the scratch buffer in the input dtype so the block returns
        # the dtype it was given instead of the default dtype of mx.zeros.
        y = mx.zeros((x.shape[0], ne, x.shape[-1]), dtype=x.dtype)
        for e, expert in enumerate(self.experts):
            # idx1: rows routed to expert e; idx2: which of the ne slots.
            idx1, idx2 = map(mx.array, np.where(inds == e))
            if idx1.size == 0:
                continue
            y[idx1, idx2] = expert(x[idx1])

        # Weighted sum over the ne selected experts.
        y = (y * scores[:, :, None]).sum(axis=1)
        return y.reshape(orig_shape)
class MixtralDecoderLayer(nn.Module):
    """One decoder layer: normalised attention and normalised MoE, each residual."""

    def __init__(self, args: ModelArgs):
        """Create the attention block, the MoE block and the two RMS norms."""
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = MixtralAttention(args)

        self.block_sparse_moe = MixtralSparseMoeBlock(args)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Run the layer on ``x``.

        Args:
            x: Layer input.
            mask: Optional mask forwarded to the attention block.
            cache: Optional attention cache forwarded to the attention block.

        Returns:
            ``(out, cache)``: the layer output and the updated cache returned
            by the attention block.
        """
        # Attention sub-block with residual connection.
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        # Mixture-of-experts sub-block with residual connection.
        r = self.block_sparse_moe(self.post_attention_layernorm(h))
        out = h + r
        return out, cache
class MixtralModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMS norm."""

    def __init__(self, args: ModelArgs):
        """Create the embedding table, ``num_hidden_layers`` layers and the norm."""
        super().__init__()
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs`` and run them through every decoder layer.

        Returns ``(hidden, cache)``: the normalised hidden states and a list
        holding one cache entry per layer. A ``cache`` list passed in is
        updated in place and returned.
        """
        h = self.embed_tokens(inputs)

        # A causal mask is only built when more than one position is processed.
        seq_len = h.shape[1]
        if seq_len > 1:
            causal = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
            mask = causal.astype(h.dtype)
        else:
            mask = None

        if cache is None:
            cache = [None] * len(self.layers)

        for idx, layer in enumerate(self.layers):
            h, layer_cache = layer(h, mask, cache[idx])
            cache[idx] = layer_cache

        return self.norm(h), cache
class Model(nn.Module):
    """Mixtral backbone followed by a bias-free linear head over the vocabulary."""

    def __init__(self, args: ModelArgs):
        """Create the backbone and the output projection."""
        super().__init__()
        self.model = MixtralModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return ``(logits, cache)`` for ``inputs``, forwarding ``cache`` to the backbone."""
        hidden, cache = self.model(inputs, cache)
        logits = self.lm_head(hidden)
        return logits, cache

View File

@ -9,6 +9,7 @@ from typing import Generator
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import models.llama as llama import models.llama as llama
import models.mixtral as mixtral
import models.phi2 as phi2 import models.phi2 as phi2
import transformers import transformers
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@ -18,6 +19,7 @@ MODEL_MAPPING = {
"llama": llama, "llama": llama,
"mistral": llama, # mistral is compatible with llama "mistral": llama, # mistral is compatible with llama
"phi": phi2, "phi": phi2,
"mixtral": mixtral,
} }
@ -150,7 +152,12 @@ def load(path_or_hf_repo: str):
model_args = model_args_class.from_dict(config) model_args = model_args_class.from_dict(config)
model = model_class(model_args) model = model_class(model_args)
if quantization is not None: if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization) nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()))