mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Block sparse MM MoEs (#782)
- Adds SwitchLinear - Adds QuantizedSwitchLinear
This commit is contained in:
committed by
GitHub
parent
199df9e110
commit
9f671228cd
165
llms/mlx_lm/models/switch_layers.py
Normal file
165
llms/mlx_lm/models/switch_layers.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class QuantizedSwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
# Freeze this model's parameters
|
||||
self.freeze()
|
||||
|
||||
def unfreeze(self, *args, **kwargs):
|
||||
"""Wrap unfreeze so that we unfreeze any layers we might contain but
|
||||
our parameters will remain frozen."""
|
||||
super().unfreeze(*args, **kwargs)
|
||||
self.freeze(recurse=False)
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.scales.shape[2] * self.group_size
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
self["biases"],
|
||||
rhs_indices=indices,
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
|
||||
class SwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.weight.shape[2]
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
num_experts, output_dims, input_dims = self.weight.shape
|
||||
ql = QuantizedSwitchLinear(
|
||||
input_dims, output_dims, num_experts, False, group_size, bits
|
||||
)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
|
||||
if "bias" in self:
|
||||
ql.bias = self.bias
|
||||
return ql
|
||||
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.silu,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x_up = self.up_proj(x, indices)
|
||||
x_gate = self.gate_proj(x, indices)
|
||||
x = self.down_proj(self.activation(x_gate) * x_up, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
|
||||
|
||||
class SwitchMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.gelu_approx,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x = self.fc1(x, indices)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
Reference in New Issue
Block a user