diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
new file mode 100644
index 00000000..d0ac2e16
--- /dev/null
+++ b/llms/mlx_lm/models/dbrx.py
@@ -0,0 +1,255 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    vocab_size: int
+    d_model: int
+    ffn_config: dict
+    attn_config: dict
+    n_layers: int
+    n_heads: int
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_heads = args.n_heads
+        self.d_model = args.d_model
+        self.head_dim = args.d_model // args.n_heads
+        self.num_key_value_heads = args.attn_config["kv_n_heads"]
+        self.clip_qkv = args.attn_config["clip_qkv"]
+        self.rope_theta = args.attn_config["rope_theta"]
+
+        self.scale = self.head_dim**-0.5
+
+        self.Wqkv = nn.Linear(
+            args.d_model,
+            (self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
+            bias=False,
+        )
+        self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
+        self.rope = nn.RoPE(
+            self.head_dim,
+            traditional=False,
+            base=self.rope_theta,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+
+        qkv = self.Wqkv(x)
+        qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
+        splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
+        queries, keys, values = mx.split(qkv, splits, axis=-1)
+
+        B, L, D = x.shape
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            queries = self.rope(queries, offset=key_cache.shape[2])
+            keys = self.rope(keys, offset=key_cache.shape[2])
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out_proj(output), (keys, values)
+
+
+class NormAttnNorm(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
+        self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
+        self.attn = Attention(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        h, cache = self.attn(self.norm_1(x), mask=mask, cache=cache)
+        x = h + x
+        return x, self.norm_2(x), cache
+
+
+class MLP(nn.Module):
+    def __init__(self, d_model: int, ffn_dim: int):
+        super().__init__()
+        self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
+        self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
+        self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
+        self.act_fn = nn.silu
+
+    def __call__(self, x: mx.array) -> mx.array:
+        current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class Router(nn.Module):
+    def __init__(self, d_model: int, num_experts: int):
+        super().__init__()
+        self.layer = nn.Linear(d_model, num_experts, bias=False)
+
+    def __call__(self, x: mx.array):
+        return self.layer(x)
+
+
+class SparseMoeBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.d_model = args.d_model
+        self.ffn_dim = args.ffn_config["ffn_hidden_size"]
+        self.num_experts = args.ffn_config["moe_num_experts"]
+        self.num_experts_per_tok = args.ffn_config["moe_top_k"]
+
+        self.router = Router(self.d_model, self.num_experts)
+        self.experts = [
+            MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
+        ]
+
+    def __call__(self, x: mx.array) -> mx.array:
+        ne = self.num_experts_per_tok
+        orig_shape = x.shape
+        x = x.reshape(-1, x.shape[-1])
+
+        gates = self.router(x)
+        gates = mx.softmax(gates.astype(mx.float32), axis=-1)
+
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
+        scores = scores.astype(x.dtype)
+
+        if self.training:
+            inds = np.array(inds)
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
+            for e, expert in enumerate(self.experts):
+                idx1, idx2 = map(mx.array, np.where(inds == e))
+                if idx1.size == 0:
+                    continue
+                y[idx1, idx2] = expert(x[idx1])
+
+            y = (y * scores[:, :, None]).sum(axis=1)
+        else:
+            y = []
+            for xt, st, it in zip(x, scores, inds.tolist()):
+                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
+                yt = (yt * st).sum(axis=-1)
+                y.append(yt)
+            y = mx.stack(y, axis=0)
+
+        return y.reshape(orig_shape)
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.ffn = SparseMoeBlock(args)
+        self.norm_attn_norm = NormAttnNorm(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, h, cache = self.norm_attn_norm(x, mask, cache)
+        out = self.ffn(h) + r
+        return out, cache
+
+
+class DBRX(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.vocab_size = args.vocab_size
+        self.wte = nn.Embedding(args.vocab_size, args.d_model)
+        self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
+        self.norm_f = nn.LayerNorm(args.d_model, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        h = self.wte(inputs)
+
+        mask = None
+        T = h.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.blocks)
+
+        for e, layer in enumerate(self.blocks):
+            h, cache[e] = layer(h, mask, cache[e])
+
+        return self.norm_f(h), cache
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.model_type = args.model_type
+        self.transformer = DBRX(args)
+        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
+        self.args = args
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out, cache = self.transformer(inputs, cache)
+        return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.transformer.blocks
+
+    def sanitize(self, weights):
+        # Split experts into sub matrices
+        num_experts = self.args.ffn_config["moe_num_experts"]
+        dim = self.args.ffn_config["ffn_hidden_size"]
+
+        pattern = "experts.mlp"
+        new_weights = {k: v for k, v in weights.items() if pattern not in k}
+        for k, v in weights.items():
+            if pattern in k:
+                experts = [
+                    (k.replace(".mlp", f".{e}") + ".weight", sv)
+                    for e, sv in enumerate(mx.split(v, num_experts, axis=0))
+                ]
+                if k.endswith("w2"):
+                    experts = [(s, sv.T) for s, sv in experts]
+                new_weights.update(experts)
+        return new_weights
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 26a56779..6e8b7324 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -143,7 +143,6 @@ class MixtralSparseMoeBlock(nn.Module):
         ).astype(gates.dtype)
 
         if self.training:
-            mx.eval(inds)
             inds = np.array(inds)
             y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
             for e, expert in enumerate(self.experts):
@@ -156,7 +155,7 @@ class MixtralSparseMoeBlock(nn.Module):
         else:
             y = []
             for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
                 yt = (yt * st).sum(axis=-1)
                 y.append(yt[None, :])
             y = mx.concatenate(y)
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 3849a70b..341e8984 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -125,7 +125,7 @@ class MOE(nn.Module):
         else:
             y = []
             for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
+                yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
                 yt = (yt * st).sum(axis=-1)
                 y.append(yt[None, :])
             y = mx.concatenate(y)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index b465146c..e7113a36 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -60,6 +60,8 @@ def linear_to_lora_layers(
         keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
         keys = set(["mixer.Wqkv", "moe.gate"])
+    elif model.model_type == "dbrx":
+        keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
    else:
        raise ValueError(f"Lora does not support {model.model_type}")
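
Note (not part of the patch): a minimal, self-contained sketch of the top-k routing used by the new SparseMoeBlock, for reviewers who want to poke at it in isolation. The gate logits are softmaxed, the top experts per token are picked with mx.argpartition, and the selected scores are renormalized with an L1 norm. The batch of 4 tokens, the 16 experts, and the value of ne below are made up for illustration; ne plays the role of ffn_config["moe_top_k"].

    import mlx.core as mx

    # 4 tokens routed over 16 experts, gate probabilities per token
    gates = mx.softmax(mx.random.normal((4, 16)), axis=-1)
    ne = 4  # number of experts kept per token

    # Indices of the ne largest gates per token (argpartition does not sort within the top-k set)
    inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
    scores = mx.take_along_axis(gates, inds, axis=-1)

    # Renormalize the selected gate scores so they sum to 1 per token
    scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
    print(inds.shape, scores.sum(axis=-1))  # (4, 4) and a vector of ones

Using argpartition rather than a full sort is enough here: only membership in the top-k set matters, since the kept scores are reweighted by the L1 normalization afterwards.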