Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
This commit is contained in:
Angelos Katharopoulos
2024-05-21 15:58:08 -07:00
committed by GitHub
parent 199df9e110
commit 9f671228cd
8 changed files with 365 additions and 143 deletions

View File

@@ -366,10 +366,11 @@ def load_model(
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
# Decide whether a single module should be quantized when loading the model.
# p is used as the prefix of the weight name; m is the module being checked.
# NOTE(review): presumably handed to the nn.quantize call below — the rest of
# that call is cut off in this view, so confirm.
def class_predicate(p, m):
# Modules that do not expose a `to_quantized` attribute are never quantized.
if not hasattr(m, "to_quantized"):
return False
# Quantize only if the loaded weights contain a "<p>.scales" entry, so
# legacy checkpoints with partially quantized layers are handled.
return f"{p}.scales" in weights
nn.quantize(
model,
**quantization,