mlx-examples/llms/mlx_lm/models/mixtral.py
devonthomas35 cc671cd1c7
Mixtral: Fix non-default arg follows default exception (#450)
Mixtral models throw the following exception when the model module is imported:
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
2024-02-18 13:30:26 -08:00

263 lines
8.4 KiB
Python

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters consumed by the Mixtral modules in this file.

    ``model_type`` is the only field without a default, so it is declared
    first; every other field carries a default value.
    """

    model_type: str
    # Each field is declared exactly once; a second annotation of the same
    # name would only overwrite the first and hide which default applies.
    vocab_size: int = 32000
    max_position_embeddings: int = 4096 * 32
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_experts_per_tok: int = 2
    # May be passed as None, in which case __post_init__ substitutes
    # num_attention_heads.
    num_key_value_heads: Optional[int] = 8
    num_local_experts: int = 8
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1e6
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        """Fall back to one key/value head per attention head when unset."""
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis, scaled by ``weight``."""

    def __init__(self, dims: int, eps: float = 1e-5):
        """Create a scale vector of ``dims`` ones and remember ``eps``."""
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square (plus ``eps``)."""
        mean_square = x.square().mean(-1, keepdims=True)
        inv_rms = mx.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def __call__(self, x):
        """Normalise in float32, cast back to the input dtype, then scale."""
        in_dtype = x.dtype
        normed = self._norm(x.astype(mx.float32))
        return self.weight * normed.astype(in_dtype)
class MixtralAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and shared K/V heads."""

    def __init__(self, args: ModelArgs):
        """Build the bias-free q/k/v/o projections and the RoPE module.

        Args:
            args: Model hyper-parameters; reads ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``max_position_embeddings``, ``rope_theta`` and
                ``rope_traditional``.
        """
        super().__init__()
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.max_position_embeddings = args.max_position_embeddings
        self.rope_theta = args.rope_theta

        # Number of query heads served by each key/value head.
        self.repeats = self.num_heads // self.num_key_value_heads

        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Apply attention to ``x``.

        Args:
            x: Input unpacked as ``(B, L, D)``.
            mask: Optional additive mask added to the attention scores.
            cache: Optional ``(keys, values)`` pair from earlier steps; its
                axis 2 length is used as the rotary position offset.

        Returns:
            A pair ``(output, (keys, values))`` where ``keys``/``values``
            include any cached entries and can be passed back as ``cache``.
        """
        # The last dimension is not needed by name; unpacking still enforces
        # that the input is rank 3.
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        # Expand K/V heads so that every query head has a matching K/V head.
        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)

        if cache is not None:
            key_cache, value_cache = cache
            # New positions start right after the cached ones.
            offset = key_cache.shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax is evaluated in float32 and cast back to the score dtype.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
class MixtralBLockSparseTop2MLP(nn.Module):
    """A single expert: SiLU-gated feed-forward block with three linear maps."""

    def __init__(self, args: ModelArgs):
        """Create the bias-free ``w1``/``w2``/``w3`` projections from ``args``."""
        super().__init__()
        self.ffn_dim = args.intermediate_size
        self.hidden_dim = args.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``w2(act_fn(w1(x)) * w3(x))``."""
        gate = self.act_fn(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class MixtralSparseMoeBlock(nn.Module):
    """Mixture-of-experts layer routing each token to its top-scoring experts."""

    def __init__(self, args: ModelArgs):
        """Create the gating projection and ``num_local_experts`` expert MLPs.

        Args:
            args: Model hyper-parameters; reads ``hidden_size``,
                ``intermediate_size``, ``num_local_experts`` and
                ``num_experts_per_tok``.
        """
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = [
            MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """Mix the outputs of the selected experts for every token in ``x``.

        Args:
            x: Input whose last axis is the feature axis; all leading axes
                are flattened into a single token axis for routing.

        Returns:
            An array reshaped back to the original shape of ``x``.
        """
        ne = self.num_experts_per_tok
        orig_shape = x.shape
        x = x.reshape(-1, x.shape[-1])

        gates = self.gate(x)
        # Partition around index ne - 1 so that the first ne slots hold the
        # ne largest gate values. Using kth=ne selects the same set but is
        # out of range when ne equals the number of experts.
        inds = mx.stop_gradient(
            mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
        )  # TODO remove it once we figure out how to fine tune TopK in MOE

        # Softmax over the selected gate values only, computed in float32.
        scores = mx.softmax(
            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
            axis=-1,
        ).astype(gates.dtype)

        if self.training:
            # Training path: batch all tokens assigned to the same expert.
            mx.eval(inds)
            inds = np.array(inds)
            # Allocate in the input dtype so the result matches the dtype
            # produced by the inference path instead of defaulting to float32.
            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
            for e, expert in enumerate(self.experts):
                # idx1: token rows routed to expert e; idx2: their top-k slot.
                idx1, idx2 = map(mx.array, np.where(inds == e))
                if idx1.size == 0:
                    continue
                y[idx1, idx2] = expert(x[idx1])

            y = (y * scores[:, :, None]).sum(axis=1)
        else:
            # Inference path: evaluate the selected experts token by token.
            y = []
            for xt, st, it in zip(x, scores, inds.tolist()):
                yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
                yt = (yt * st).sum(axis=-1)
                y.append(yt[None, :])
            y = mx.concatenate(y)

        return y.reshape(orig_shape)
class MixtralDecoderLayer(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm MoE, each residual."""

    def __init__(self, args: ModelArgs):
        """Create the attention block, the MoE block and both RMSNorm layers."""
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = MixtralAttention(args)

        self.block_sparse_moe = MixtralSparseMoeBlock(args)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Run the layer on ``x``.

        Args:
            x: Layer input.
            mask: Optional mask forwarded to the attention block.
            cache: Optional cache forwarded to the attention block.

        Returns:
            A pair ``(out, cache)`` where ``cache`` is the updated cache
            returned by the attention block.
        """
        # Attention sub-block with residual connection.
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        # Mixture-of-experts sub-block with residual connection.
        r = self.block_sparse_moe(self.post_attention_layernorm(h))
        out = h + r
        return out, cache
class MixtralModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        """Build the embedding table, ``num_hidden_layers`` layers and the norm."""
        super().__init__()
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, cache=None):
        """Embed ``inputs``, run every layer, and normalise the result.

        Returns the normalised hidden states together with the per-layer
        cache list. A cache list supplied by the caller is updated in place.
        """
        h = self.embed_tokens(inputs)

        # A causal mask is only built when more than one position is processed.
        seq_len = h.shape[1]
        if seq_len > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
            mask = mask.astype(h.dtype)
        else:
            mask = None

        if cache is None:
            cache = [None] * len(self.layers)

        for idx, layer in enumerate(self.layers):
            h, cache[idx] = layer(h, mask, cache[idx])

        hidden = self.norm(h)
        return hidden, cache
class Model(nn.Module):
    """Mixtral backbone followed by a bias-free linear head over the vocabulary."""

    def __init__(self, args: ModelArgs):
        """Create the backbone and the output projection from ``args``."""
        super().__init__()
        self.model_type = args.model_type
        self.model = MixtralModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return ``(lm_head(hidden), cache)`` for the given ``inputs``."""
        hidden, cache = self.model(inputs, cache)
        logits = self.lm_head(hidden)
        return logits, cache