2024-08-17 06:28:39 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-05-22 06:58:08 +08:00
|
|
|
import math
|
2024-01-15 23:18:14 +08:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import Dict, Optional, Tuple, Union
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
|
2024-07-26 07:45:22 +08:00
|
|
|
from .base import BaseModelArgs, create_attention_mask
|
2024-05-22 06:58:08 +08:00
|
|
|
from .switch_layers import SwitchGLU
|
2024-01-15 23:18:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyperparameters for a Mixtral sparse mixture-of-experts model.

    ``model_type`` has no default and is therefore declared first: a dataclass
    raises ``TypeError: non-default argument 'model_type' follows default
    argument`` at import time if a field without a default follows fields that
    have defaults.
    """

    model_type: str
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_experts_per_tok: int = 2
    # May be None (e.g. when absent/null in a config); __post_init__ then falls
    # back to one key/value head per attention head.
    num_key_value_heads: Optional[int] = 8
    num_local_experts: int = 8
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1e6
    rope_traditional: bool = False
    # NOTE(review): accepted from the config but not read by the attention
    # layer in this file, whose RoPE uses only rope_theta / rope_traditional.
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        # Resolve the "no grouped-query attention" case to a concrete head count.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
|
|
|
|
|
|
|
|
|
|
|
|
class MixtralAttention(nn.Module):
    """Self-attention with rotary position embeddings and grouped KV heads.

    Queries are projected to ``num_attention_heads`` heads and keys/values to
    ``num_key_value_heads`` heads. Keys/values are handed to
    ``mx.fast.scaled_dot_product_attention`` without being repeated per query
    head (the fused kernel supports grouped-query attention).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.rope_theta = args.rope_theta

        # Standard 1/sqrt(d) attention scaling.
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ) -> mx.array:
        """Run attention over ``x`` and, if given, the cached history.

        Args:
            x: Input of shape ``(B, L, hidden_size)``.
            mask: Optional mask forwarded unchanged to the SDPA kernel.
            cache: Optional per-layer KV cache object (not a tuple). It must
                expose ``offset``, used as the RoPE position offset for the new
                tokens, and ``update_and_fetch(keys, values)``, whose returned
                keys/values are what attention is computed over.

        Returns:
            Array of shape ``(B, L, hidden_size)``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            # Rotate the new tokens at their position after the cached prefix,
            # then merge them into the cache.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
|
2024-01-15 23:18:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MixtralSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k routing.

    A linear router scores every expert per token; the ``num_experts_per_tok``
    highest-scoring experts are evaluated through ``SwitchGLU`` and their
    outputs are blended with softmax weights computed over the selected
    router logits only.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok

        # Router: one logit per expert for every token.
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # All experts' GLU weights live in one stacked module.
        self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)

    def __call__(self, x: mx.array) -> mx.array:
        router_logits = self.gate(x)
        top_k = self.num_experts_per_tok

        # Partitioning the negated logits places the top_k largest entries in
        # the first top_k slots (in no particular order).
        partitioned = mx.argpartition(-router_logits, kth=top_k - 1, axis=-1)
        expert_ids = mx.stop_gradient(partitioned[..., :top_k])

        selected_logits = mx.take_along_axis(router_logits, expert_ids, axis=-1)
        mix_weights = mx.softmax(selected_logits, axis=-1, precise=True)

        # expert_out carries one slot per selected expert on axis -2; weight
        # each slot and collapse that axis.
        expert_out = self.switch_mlp(x, expert_ids)
        weighted = expert_out * mix_weights[..., None]
        return weighted.sum(axis=-2)
|
2024-01-15 23:18:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MixtralDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention followed by a pre-norm
    sparse MoE feed-forward, each wrapped in a residual connection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = MixtralAttention(args)

        self.block_sparse_moe = MixtralSparseMoeBlock(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ) -> mx.array:
        """Apply the block to ``x``.

        Args:
            x: Hidden states fed to the attention sub-layer.
            mask: Optional attention mask, passed through to ``self_attn``.
            cache: Optional per-layer KV cache object passed through to
                ``self_attn`` (it is accessed via ``offset`` /
                ``update_and_fetch`` there, so it is not a tuple).

        Returns:
            Hidden states with the same shape as ``x``.
        """
        h = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return h + self.block_sparse_moe(self.post_attention_layernorm(h))
|
2024-01-15 23:18:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MixtralModel(nn.Module):
    """Mixtral backbone: token embedding, decoder stack, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MixtralDecoderLayer(args=args) for _ in range(self.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs``, run every decoder layer, and normalize.

        ``cache`` is either None or a sequence with one entry per layer; it is
        zipped against ``self.layers``. The attention mask is built from the
        cache exactly as supplied by the caller.
        """
        hidden = self.embed_tokens(inputs)
        mask = create_attention_mask(hidden, cache)

        per_layer_cache = [None] * len(self.layers) if cache is None else cache
        for decoder_layer, layer_cache in zip(self.layers, per_layer_cache):
            hidden = decoder_layer(hidden, mask, layer_cache)

        return self.norm(hidden)
|
2024-01-15 23:18:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Mixtral causal language model: ``MixtralModel`` backbone + LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = MixtralModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Project the backbone's hidden states to vocabulary logits."""
        hidden = self.model(inputs, cache)
        return self.lm_head(hidden)

    def sanitize(self, weights):
        """Fold per-expert checkpoint tensors into stacked SwitchGLU tensors.

        For every layer, each ``block_sparse_moe.experts.{e}.{w1,w2,w3}.{weight,
        scales,biases}`` entry that exists is popped for all experts and
        stacked (``mx.stack``, new leading axis) under
        ``block_sparse_moe.switch_mlp.{gate_proj,down_proj,up_proj}``.
        Checkpoints lacking layer 0 / expert 0 ``w1.weight`` are returned
        untouched. ``weights`` is modified in place and also returned.
        """
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights

        expert_to_switch = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}
        for layer_idx in range(self.args.num_hidden_layers):
            moe_prefix = f"model.layers.{layer_idx}.block_sparse_moe"
            for src_name, dst_name in expert_to_switch.items():
                for suffix in ("weight", "scales", "biases"):
                    # scales/biases only exist for quantized checkpoints.
                    if f"{moe_prefix}.experts.0.{src_name}.{suffix}" not in weights:
                        continue
                    per_expert = [
                        weights.pop(f"{moe_prefix}.experts.{e}.{src_name}.{suffix}")
                        for e in range(self.args.num_local_experts)
                    ]
                    weights[f"{moe_prefix}.switch_mlp.{dst_name}.{suffix}"] = (
                        mx.stack(per_expert)
                    )
        return weights

    @property
    def layers(self):
        """Decoder layers of the backbone."""
        return self.model.layers

    @property
    def head_dim(self):
        """Per-head attention width (hidden size split evenly across heads)."""
        args = self.args
        return args.hidden_size // args.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value heads used by each attention layer."""
        return self.args.num_key_value_heads
|