Mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
chore: fix the convert.py script for weights that are not sanitized and support quant for non-32 dimensions (#340)
* chore: fix convert script for weights not sanitized and support quant for non-32 dim
* Update llms/mlx_lm/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* chore: fix typo
--------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
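At a high level, the change makes convert build the model through load_model (which sanitizes the checkpoint weights), derive the flat weight dict from the module's parameters, and quantize the module in place. A minimal sketch of that flow, using the helpers named in the diff below; the mlx_lm.utils import path and the default group size/bits are assumptions, not taken from this commit:

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from mlx_lm.utils import get_model_path, linear_class_predicate, load_model  # assumed import path

def convert_sketch(hf_path, q_group_size=64, q_bits=4, quantize=True):
    # Build the model from the (sanitized) Hugging Face checkpoint.
    model = load_model(get_model_path(hf_path))
    # Flatten nested parameters into a {dotted_name: array} dict and cast.
    weights = {k: v.astype(mx.float16) for k, v in tree_flatten(model.parameters())}
    if quantize:
        # Push the cast weights back into the module, then quantize it in place.
        model.load_weights(list(weights.items()))
        nn.QuantizedLinear.quantize_module(
            model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
        )
        # Re-flatten so the returned dict contains the quantized parameters.
        weights = dict(tree_flatten(model.parameters()))
    return weights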
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import transformers
 from mlx.utils import tree_flatten
 
-from .utils import get_model_path, linear_class_predicate, load
+from .utils import get_model_path, linear_class_predicate, load_model
 
 MAX_FILE_SIZE_GB = 15
 
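The tree_flatten import is what lets convert derive the weight dict from the loaded module instead of reading the safetensors files itself. A toy illustration of the key format it produces; the Toy module is hypothetical and only here to show the dotted names:

import mlx.nn as nn
from mlx.utils import tree_flatten

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

# tree_flatten yields (dotted_name, array) pairs for every nested parameter.
weights = dict(tree_flatten(Toy().parameters()))
print(sorted(weights))  # ['proj.bias', 'proj.weight']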
@@ -60,30 +60,23 @@ def fetch_from_hub(
 ) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
     model_path = get_model_path(model_path)
 
-    weight_files = glob.glob(f"{model_path}/*.safetensors")
-    if not weight_files:
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf).items())
+    model = load_model(model_path)
 
     config = transformers.AutoConfig.from_pretrained(model_path)
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
 
-    return weights, config.to_dict(), tokenizer
+    return model, config.to_dict(), tokenizer
 
 
 def quantize_model(
-    weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
 ) -> tuple:
     """
     Applies quantization to the model weights.
 
     Args:
-        weights (dict): Model weights.
+        model (nn.Module): The model to be quantized.
         config (dict): Model configuration.
-        hf_path (str): HF model path..
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
 
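With the new signature, callers hand quantize_model the already-loaded nn.Module rather than a raw weights dict plus hf_path. A hedged sketch of the call site; the mlx_lm.convert import path and the repo id are illustrative assumptions:

from mlx_lm.convert import fetch_from_hub, quantize_model  # assumed module path

model, config, tokenizer = fetch_from_hub("some-org/some-model")  # illustrative repo id
weights, config = quantize_model(model, config, q_group_size=64, q_bits=4)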
@@ -91,8 +84,6 @@ def quantize_model(
         tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    model, _ = load(hf_path)
-    model.load_weights(list(weights.items()))
 
     nn.QuantizedLinear.quantize_module(
         model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
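quantize_module swaps each nn.Linear for which linear_class_predicate returns True with an nn.QuantizedLinear; the actual predicate lives in .utils and is not part of this diff. A hedged example of the kind of shape-aware predicate that skips layers the group size cannot divide evenly (illustrative only, not the repo's predicate):

import mlx.nn as nn

def shape_aware_predicate(module, group_size=64):
    # Illustrative only: quantize Linear layers whose input dimension is a
    # multiple of the quantization group size and leave the rest in float.
    return isinstance(module, nn.Linear) and module.weight.shape[-1] % group_size == 0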
@@ -183,12 +174,16 @@ def convert(
     upload_repo: str = None,
 ):
     print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(hf_path)
+    model, config, tokenizer = fetch_from_hub(hf_path)
 
+    weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
 
     if quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)
 
     mlx_path = Path(mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
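For reference, a typical call into the updated convert entry point could look like the following; the module path, repo id, and quantization settings are illustrative assumptions, while the parameter names match the diff above:

from mlx_lm.convert import convert  # assumed module path

convert(
    hf_path="some-org/some-model",  # illustrative Hugging Face repo id
    mlx_path="mlx_model",
    quantize=True,
    q_group_size=64,  # assumed values; check the script's own defaults
    q_bits=4,
)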