mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Block sparse MM MoEs (#782)
- Adds SwitchLinear - Adds QuantizedSwitchLinear
This commit is contained in:
parent
199df9e110
commit
9f671228cd
@ -7,7 +7,7 @@ from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from .gguf import convert_to_gguf
|
||||
from .tuner.dora import DoRALinear
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.lora import LoRALinear, LoRASwitchLinear
|
||||
from .tuner.utils import apply_lora_layers, dequantize
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
@ -82,7 +82,7 @@ def main() -> None:
|
||||
fused_linears = [
|
||||
(n, m.to_linear())
|
||||
for n, m in model.named_modules()
|
||||
if isinstance(m, (LoRALinear, DoRALinear))
|
||||
if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
|
||||
]
|
||||
|
||||
model.update_modules(tree_unflatten(fused_linears))
|
||||
|
@ -1,11 +1,12 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -91,24 +92,6 @@ class MixtralAttention(nn.Module):
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MixtralBLockSparseTop2MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.hidden_dim = args.hidden_size
|
||||
|
||||
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
||||
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
||||
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class MixtralSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -120,43 +103,20 @@ class MixtralSparseMoeBlock(nn.Module):
|
||||
# gating
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
|
||||
self.experts = [
|
||||
MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
|
||||
]
|
||||
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
if self.training:
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
return y
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
@ -235,6 +195,23 @@ class Model(nn.Module):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
if to_join:
|
||||
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
|
||||
mx.stack(to_join)
|
||||
)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
@ -5,7 +5,8 @@ from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .switch_layers import SwitchMLP
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -75,17 +76,6 @@ class RoPEAttention(nn.Module):
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class MOE(nn.Module):
|
||||
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
@ -93,40 +83,23 @@ class MOE(nn.Module):
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
|
||||
self.switch_mlp = SwitchMLP(
|
||||
self.dim, self.hidden_dim, self.num_experts, bias=True
|
||||
)
|
||||
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
if self.training:
|
||||
ys = []
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.mlp):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = (y * scores[..., None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
return y
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
@ -202,6 +175,21 @@ class Model(nn.Module):
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_layers):
|
||||
prefix = f"transformer.h.{l}"
|
||||
for n in ["fc1", "fc2"]:
|
||||
for k in ["weight", "scales", "biases", "bias"]:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
if to_join:
|
||||
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
@ -1,11 +1,12 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -92,7 +93,7 @@ class Attention(nn.Module):
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(dim, num_experts, bias=False)
|
||||
self.experts = [
|
||||
Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts)
|
||||
]
|
||||
self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
|
||||
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
ne = self.top_k
|
||||
B, L, D = x.shape
|
||||
x = x.reshape(-1, D)
|
||||
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype)
|
||||
|
||||
if self.training:
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((B * L, ne, D), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt)
|
||||
|
||||
y = mx.stack(y, axis=0)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
shared_expert_output = self.shared_expert(x)
|
||||
shared_expert_output = (
|
||||
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
|
||||
)
|
||||
|
||||
y += shared_expert_output
|
||||
|
||||
return y.reshape(B, L, -1)
|
||||
return y + shared_expert_output
|
||||
|
||||
|
||||
class Qwen2MoeDecoderLayer(nn.Module):
|
||||
@ -243,12 +219,19 @@ class Model(nn.Module):
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
if to_join:
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
|
165
llms/mlx_lm/models/switch_layers.py
Normal file
165
llms/mlx_lm/models/switch_layers.py
Normal file
@ -0,0 +1,165 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class QuantizedSwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
# Freeze this model's parameters
|
||||
self.freeze()
|
||||
|
||||
def unfreeze(self, *args, **kwargs):
|
||||
"""Wrap unfreeze so that we unfreeze any layers we might contain but
|
||||
our parameters will remain frozen."""
|
||||
super().unfreeze(*args, **kwargs)
|
||||
self.freeze(recurse=False)
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.scales.shape[2] * self.group_size
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
self["biases"],
|
||||
rhs_indices=indices,
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
|
||||
class SwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.weight.shape[2]
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
num_experts, output_dims, input_dims = self.weight.shape
|
||||
ql = QuantizedSwitchLinear(
|
||||
input_dims, output_dims, num_experts, False, group_size, bits
|
||||
)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
|
||||
if "bias" in self:
|
||||
ql.bias = self.bias
|
||||
return ql
|
||||
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.silu,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x_up = self.up_proj(x, indices)
|
||||
x_gate = self.gate_proj(x, indices)
|
||||
x = self.down_proj(self.activation(x_gate) * x_up, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
|
||||
|
||||
class SwitchMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.gelu_approx,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x = self.fc1(x, indices)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x, indices)
|
||||
|
||||
return x.squeeze(-2)
|
@ -5,6 +5,8 @@ import math
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
@ -100,3 +102,98 @@ class LoRALinear(nn.Module):
|
||||
y = self.linear(x)
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
return y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
|
||||
class LoRASwitchLinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(
|
||||
linear: nn.Module,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
):
|
||||
lora_lin = LoRASwitchLinear(
|
||||
input_dims=linear.input_dims,
|
||||
output_dims=linear.output_dims,
|
||||
num_experts=linear.num_experts,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def to_linear(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
is_quantized = isinstance(linear, QuantizedSwitchLinear)
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = mx.float16
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
linear.scales,
|
||||
linear.biases,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
num_experts, output_dims, input_dims = weight.shape
|
||||
fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b).astype(dtype)
|
||||
lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype)
|
||||
fused_linear.weight = weight + lora_b @ lora_a
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if is_quantized and not de_quantize:
|
||||
fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
r: int = 8,
|
||||
alpha: float = 16,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale * (alpha / r)
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(r * num_experts, input_dims),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(num_experts, output_dims, r))
|
||||
self.num_experts = num_experts
|
||||
|
||||
def __call__(self, x, indices):
|
||||
shape = x.shape[:-3] + (self.num_experts, -1)
|
||||
|
||||
y = self.linear(x, indices)
|
||||
z = (self.dropout(x) @ self.lora_a.T).reshape(shape)
|
||||
z = mx.take_along_axis(z, indices[..., None], axis=-2)
|
||||
z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
|
||||
|
||||
return y + (self.scale * z).astype(x.dtype)
|
||||
|
@ -9,8 +9,9 @@ import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
|
||||
from .dora import DoRALinear
|
||||
from .lora import LoRALinear
|
||||
from .lora import LoRALinear, LoRASwitchLinear
|
||||
|
||||
|
||||
def build_schedule(schedule_config: Dict):
|
||||
@ -58,11 +59,21 @@ def linear_to_lora_layers(
|
||||
f"Requested {num_lora_layers} LoRA layers "
|
||||
f"but the model only has {num_layers} layers."
|
||||
)
|
||||
cls = DoRALinear if use_dora else LoRALinear
|
||||
|
||||
def to_lora(lin):
|
||||
return cls.from_linear(
|
||||
lin,
|
||||
def to_lora(layer):
|
||||
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
|
||||
LoRALayer = DoRALinear if use_dora else LoRALinear
|
||||
elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
|
||||
if use_dora:
|
||||
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
|
||||
LoRALayer = LoRASwitchLinear
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Can't convert layer of type {type(layer).__name__} to LoRA"
|
||||
)
|
||||
|
||||
return LoRALayer.from_linear(
|
||||
layer,
|
||||
r=config["rank"],
|
||||
alpha=config["alpha"],
|
||||
scale=config["scale"],
|
||||
|
@ -366,10 +366,11 @@ def load_model(
|
||||
|
||||
if (quantization := config.get("quantization", None)) is not None:
|
||||
# Handle legacy models which may not have everything quantized
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
def class_predicate(p, m):
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
return f"{p}.scales" in weights
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
**quantization,
|
||||
|
Loading…
Reference in New Issue
Block a user