Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
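
Below is a minimal sketch of what a switch (per-expert) linear layer built on MLX's block-sparse `mx.gather_mm` can look like. The class name `SwitchLinearSketch`, the tensor shapes, and the call signature are illustrative assumptions for this description, not a copy of the mlx-lm implementation.

```python
# Illustrative sketch only: names, shapes, and the gather_mm call pattern are
# assumptions about how a per-expert linear layer can be built on MLX's
# block-sparse matmul; they are not taken verbatim from mlx-lm.
import math

import mlx.core as mx
import mlx.nn as nn


class SwitchLinearSketch(nn.Module):
    """Applies one of `num_experts` linear layers to each input row,
    selected by an expert index, without a dense matmul over all experts."""

    def __init__(self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True):
        super().__init__()
        scale = math.sqrt(1.0 / input_dims)
        # All expert weights stacked: (num_experts, output_dims, input_dims)
        self.weight = mx.random.uniform(
            low=-scale, high=scale, shape=(num_experts, output_dims, input_dims)
        )
        if bias:
            self.bias = mx.zeros((num_experts, output_dims))

    def __call__(self, x, indices):
        # x: (..., 1, input_dims), indices: (...,) one expert id per row of x.
        # gather_mm gathers only the selected expert weight for each row and
        # multiplies, which is the block-sparse MM this commit relies on.
        y = mx.gather_mm(x, self.weight.swapaxes(-1, -2), rhs_indices=indices)
        if "bias" in self:
            y = y + mx.expand_dims(self.bias[indices], -2)
        return y


if __name__ == "__main__":
    # Assumed call pattern: 8 tokens, 4 experts, one expert id per token.
    layer = SwitchLinearSketch(input_dims=16, output_dims=32, num_experts=4)
    x = mx.random.normal((8, 1, 16))
    experts = mx.random.randint(0, 4, (8,))
    print(layer(x, experts).shape)  # (8, 1, 32)
```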
Author: Angelos Katharopoulos
Date: 2024-05-21 15:58:08 -07:00
Parent: 199df9e110
Commit: 9f671228cd
8 changed files with 365 additions and 143 deletions


@@ -7,7 +7,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRALinear
-from .tuner.lora import LoRALinear
+from .tuner.lora import LoRALinear, LoRASwitchLinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
     fetch_from_hub,
@@ -82,7 +82,7 @@ def main() -> None:
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, (LoRALinear, DoRALinear))
+        if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
     ]
     model.update_modules(tree_unflatten(fused_linears))
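
For context on why `LoRASwitchLinear` joins the fuse list above: fusing a LoRA-adapted switch layer folds each expert's low-rank update back into the stacked expert weights, so the result is a plain switch linear layer again. A hypothetical sketch of that merge, assuming per-expert adapter matrices `lora_a`/`lora_b` and a `scale` factor (the names and shapes are assumptions, not the mlx-lm attribute layout):

```python
# Hypothetical merge of per-expert LoRA adapters into stacked expert weights.
import mlx.core as mx


def fuse_switch_lora(weight, lora_a, lora_b, scale):
    # weight: (num_experts, output_dims, input_dims)
    # lora_a: (num_experts, rank, input_dims)
    # lora_b: (num_experts, output_dims, rank)
    # Per-expert merge: W_e <- W_e + scale * B_e @ A_e (batched over experts)
    return weight + scale * mx.matmul(lora_b, lora_a)
```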