Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
Author: Angelos Katharopoulos
Date: 2024-05-21 15:58:08 -07:00
Committed by: GitHub
Parent: 199df9e110
Commit: 9f671228cd
8 changed files with 365 additions and 143 deletions
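For a rough idea of what the new layers do, here is a minimal sketch of a per-expert ("switch") linear layer. It is an illustration only: the class name `NaiveSwitchLinear`, its constructor, and the dense gather-based forward pass are assumptions made for clarity; the actual `SwitchLinear` and `QuantizedSwitchLinear` in `switch_layers.py` are built around the block-sparse matmuls this PR targets rather than gathering full weight matrices per token.

```python
import math

import mlx.core as mx
import mlx.nn as nn


class NaiveSwitchLinear(nn.Module):
    """Per-expert linear layer: every token is projected by the weights of
    the expert it was routed to. Dense, illustrative version only."""

    def __init__(self, input_dims: int, output_dims: int, num_experts: int):
        super().__init__()
        scale = math.sqrt(1 / input_dims)
        # One (output_dims, input_dims) weight matrix and one bias per expert.
        self.weight = mx.random.uniform(
            low=-scale, high=scale, shape=(num_experts, output_dims, input_dims)
        )
        self.bias = mx.zeros((num_experts, output_dims))

    def __call__(self, x: mx.array, indices: mx.array) -> mx.array:
        # x: (..., input_dims); indices: (...,) expert id chosen per token.
        w = mx.take(self.weight, indices, axis=0)  # (..., output_dims, input_dims)
        b = mx.take(self.bias, indices, axis=0)    # (..., output_dims)
        y = mx.squeeze(w @ mx.expand_dims(x, -1), -1)
        return y + b
```

In an MoE block the `indices` would come from a router, e.g. `indices = mx.argmax(gate(x), axis=-1)` for top-1 routing, where `gate` is a small linear scoring layer (hypothetical names).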


@@ -9,8 +9,9 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
+from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
-from .lora import LoRALinear
+from .lora import LoRALinear, LoRASwitchLinear
 
 
 def build_schedule(schedule_config: Dict):
@@ -58,11 +59,21 @@ def linear_to_lora_layers(
f"Requested {num_lora_layers} LoRA layers "
f"but the model only has {num_layers} layers."
)
cls = DoRALinear if use_dora else LoRALinear
def to_lora(lin):
return cls.from_linear(
lin,
def to_lora(layer):
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
LoRALayer = DoRALinear if use_dora else LoRALinear
elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
return LoRALayer.from_linear(
layer,
r=config["rank"],
alpha=config["alpha"],
scale=config["scale"],