chore(lora): support mixtral in lora example (#343)

Author: Anchen
Date: 2024-01-20 06:07:45 -08:00
Committed by: GitHub
Commit: 1415595409 (parent: 527cea4027)
6 changed files with 279 additions and 4 deletions

@@ -9,6 +9,7 @@ from typing import Generator
 import mlx.core as mx
 import mlx.nn as nn
 import models.llama as llama
+import models.mixtral as mixtral
 import models.phi2 as phi2
 import transformers
 from huggingface_hub import snapshot_download
@@ -18,6 +19,7 @@ MODEL_MAPPING = {
     "llama": llama,
     "mistral": llama,  # mistral is compatible with llama
     "phi": phi2,
+    "mixtral": mixtral,
 }
@@ -150,7 +152,12 @@ def load(path_or_hf_repo: str):
     model_args = model_args_class.from_dict(config)
     model = model_class(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.QuantizedLinear.quantize_module(
+            model,
+            **quantization,
+            linear_class_predicate=lambda m: isinstance(m, nn.Linear)
+            and m.weight.shape[0] != 8,
+        )
     model.load_weights(list(weights.items()))
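
The added linear_class_predicate keeps quantization away from Mixtral's expert router: the gate is an nn.Linear that maps hidden states to scores over the 8 local experts, so its weight has a first dimension of 8, while every other Linear is quantized as before. A hedged sketch of the same idea in isolation (the function name and the group_size/bits defaults are assumptions, not taken from this diff):

```python
import mlx.nn as nn

def quantize_skipping_gate(model, group_size=64, bits=4):
    # Quantize every nn.Linear except layers whose output dimension is 8,
    # which for Mixtral corresponds to the per-token router over 8 experts.
    nn.QuantizedLinear.quantize_module(
        model,
        group_size=group_size,
        bits=bits,
        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
        and m.weight.shape[0] != 8,
    )
```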