feat(mlx_lm): add mixtral support in mlx_lm (#318)

* feat: add mixtral support in mlx_lm

* chore: update doc
Author: Anchen
Date: 2024-01-15 07:18:14 -08:00
Committed by: GitHub
Parent: 19b6167d81
Commit: 195bec2fa3
4 changed files with 266 additions and 9 deletions


@@ -10,16 +10,21 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer, PreTrainedTokenizer

 # Local imports
-from .models import llama, phi2
+from .models import llama, mixtral, phi2
 from .models.base import BaseModelArgs

 # Constants
 MODEL_MAPPING = {
     "llama": llama,
     "mistral": llama,  # mistral is compatible with llama
+    "mixtral": mixtral,
     "phi": phi2,
 }
+
+linear_class_predicate = (
+    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
+)  # TODO remove this once we support quantization for non-multiples of 32


 def _get_classes(config: dict):
     """
@@ -171,7 +176,11 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     model = model_class(model_args)

     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.QuantizedLinear.quantize_module(
+            model,
+            **quantization,
+            linear_class_predicate=linear_class_predicate,
+        )

     model.load_weights(list(weights.items()))
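
With these changes, a Mixtral checkpoint (quantized or full precision) loads through the same path as Llama and Phi-2. A hedged usage sketch follows: the Hugging Face repo id is only an example, and generate with its max_tokens argument is assumed to be exposed alongside load in the mlx_lm package rather than shown in this diff.

from mlx_lm import load, generate

# load() fetches the checkpoint, instantiates the mixtral model class via
# MODEL_MAPPING, and applies QuantizedLinear.quantize_module with the predicate
# when the model config carries a quantization section.
model, tokenizer = load("mistralai/Mixtral-8x7B-Instruct-v0.1")  # example repo id

text = generate(model, tokenizer, prompt="Write a haiku about MoE models.", max_tokens=64)
print(text)
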