Mirror of https://github.com/ml-explore/mlx-examples.git
Block sparse MM MoEs (#782)
- Adds SwitchLinear
- Adds QuantizedSwitchLinear
parent 199df9e110
commit 9f671228cd
@@ -9,8 +9,9 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
+from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
-from .lora import LoRALinear
+from .lora import LoRALinear, LoRASwitchLinear
 
 
 def build_schedule(schedule_config: Dict):
@@ -58,11 +59,21 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )
-    cls = DoRALinear if use_dora else LoRALinear
-
-    def to_lora(lin):
-        return cls.from_linear(
-            lin,
+    def to_lora(layer):
+        if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
+            LoRALayer = DoRALinear if use_dora else LoRALinear
+        elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
+            if use_dora:
+                raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
+
+            LoRALayer = LoRASwitchLinear
+        else:
+            raise ValueError(
+                f"Can't convert layer of type {type(layer).__name__} to LoRA"
+            )
+
+        return LoRALayer.from_linear(
+            layer,
             r=config["rank"],
             alpha=config["alpha"],
             scale=config["scale"],
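For context: the SwitchLinear and QuantizedSwitchLinear layers imported in the first hunk live in ..models/switch_layers.py and hold one weight matrix per expert, applying each token's routed expert through the block-sparse matmul the commit title refers to. Below is a minimal sketch of that idea only, not the mlx-examples implementation; the class name NaiveSwitchLinear, the initialization, and the plain gather-plus-matmul used in place of a fused kernel are illustrative assumptions.

import mlx.core as mx
import mlx.nn as nn


class NaiveSwitchLinear(nn.Module):
    """One linear layer per expert, selected per token by an index array."""

    def __init__(self, input_dims, output_dims, num_experts, bias=True):
        super().__init__()
        scale = input_dims**-0.5
        # One (output_dims, input_dims) weight matrix per expert.
        self.weight = mx.random.uniform(
            -scale, scale, (num_experts, output_dims, input_dims)
        )
        if bias:
            self.bias = mx.zeros((num_experts, output_dims))

    def __call__(self, x, indices):
        # x: (..., input_dims) token activations; indices: (...,) expert ids.
        w = self.weight[indices]            # gather each token's expert weights
        y = (w @ x[..., None]).squeeze(-1)  # per-token matmul with its expert
        if "bias" in self:
            y = y + self.bias[indices]
        return y


# Example: 10 tokens, 4 experts, one expert chosen per token.
layer = NaiveSwitchLinear(64, 64, num_experts=4)
x = mx.random.normal((10, 64))
idx = mx.random.randint(0, 4, (10,))
y = layer(x, idx)  # (10, 64)

A fused gathered matmul avoids materializing weight[indices] for every token, which is the point of the block-sparse kernels this commit moves the MoE models onto.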
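The second hunk routes SwitchLinear and QuantizedSwitchLinear through LoRASwitchLinear (now imported from .lora), whose definition is outside this diff. As a rough illustration of what per-expert low-rank adaptation means, the standalone sketch below computes y = W[e] x + scale * B[e] (A[e] x) for each token's expert id e; the shapes, initialization, and scale value are illustrative assumptions, not the mlx-examples code.

import math

import mlx.core as mx

num_experts, input_dims, output_dims, r = 4, 64, 64, 8
scale = 2.0  # illustrative stand-in for the config["scale"] passed above

# Frozen per-expert base weights and trainable low-rank factors A, B.
W = mx.random.normal((num_experts, output_dims, input_dims)) * input_dims**-0.5
bound = 1 / math.sqrt(input_dims)
A = mx.random.uniform(-bound, bound, (num_experts, r, input_dims))
B = mx.zeros((num_experts, output_dims, r))  # zero init: adapter starts as a no-op

x = mx.random.normal((10, input_dims))        # 10 tokens
e = mx.random.randint(0, num_experts, (10,))  # expert id chosen per token

base = (W[e] @ x[..., None]).squeeze(-1)            # base expert projection
delta = (B[e] @ (A[e] @ x[..., None])).squeeze(-1)  # low-rank update for that expert
y = base + scale * delta

Merging such an adapter adds scale * B[e] A[e] to each expert's weight matrix, mirroring how standard LoRA adapters are merged back into a base linear layer.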