mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-03 13:08:08 +08:00
chore(lora): support mixtral in lora example (#343)
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Generator
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import models.llama as llama
|
||||
import models.mixtral as mixtral
|
||||
import models.phi2 as phi2
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
@@ -18,6 +19,7 @@ MODEL_MAPPING = {
|
||||
"llama": llama,
|
||||
"mistral": llama, # mistral is compatible with llama
|
||||
"phi": phi2,
|
||||
"mixtral": mixtral,
|
||||
}
|
||||
|
||||
|
||||
@@ -150,7 +152,12 @@ def load(path_or_hf_repo: str):
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = model_class(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
**quantization,
|
||||
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
||||
and m.weight.shape[0] != 8,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user