From 527cea4027974b6a44d3d16d62b385f10dcbcb65 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 19 Jan 2024 21:07:21 -0800
Subject: [PATCH] chore: fix the convert.py script for weights are not
sanitized and support quant for non-32 dimensions (#340)
* chore: fix convert script for weights not sanitized and suport quant for non 32 dim
* Update llms/mlx_lm/utils.py
Co-authored-by: Awni Hannun
* chore: fix typo
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/convert.py | 27 +++++++++++---------------
llms/mlx_lm/utils.py | 43 ++++++++++++++++++++++++++++++++----------
2 files changed, 44 insertions(+), 26 deletions(-)
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 91a0aa42..ead4f0e4 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -10,7 +10,7 @@ import mlx.nn as nn
import transformers
from mlx.utils import tree_flatten
-from .utils import get_model_path, linear_class_predicate, load
+from .utils import get_model_path, linear_class_predicate, load_model
MAX_FILE_SIZE_GB = 15
@@ -60,30 +60,23 @@ def fetch_from_hub(
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
model_path = get_model_path(model_path)
- weight_files = glob.glob(f"{model_path}/*.safetensors")
- if not weight_files:
- raise FileNotFoundError(f"No safetensors found in {model_path}")
-
- weights = {}
- for wf in weight_files:
- weights.update(mx.load(wf).items())
+ model = load_model(model_path)
config = transformers.AutoConfig.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
- return weights, config.to_dict(), tokenizer
+ return model, config.to_dict(), tokenizer
def quantize_model(
- weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
+ model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> tuple:
"""
Applies quantization to the model weights.
Args:
- weights (dict): Model weights.
+ model (nn.Module): The model to be quantized.
config (dict): Model configuration.
- hf_path (str): HF model path..
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
@@ -91,8 +84,6 @@ def quantize_model(
tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
- model, _ = load(hf_path)
- model.load_weights(list(weights.items()))
nn.QuantizedLinear.quantize_module(
model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
@@ -183,12 +174,16 @@ def convert(
upload_repo: str = None,
):
print("[INFO] Loading")
- weights, config, tokenizer = fetch_from_hub(hf_path)
+ model, config, tokenizer = fetch_from_hub(hf_path)
+
+ weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
+
if quantize:
print("[INFO] Quantizing")
- weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
+ model.load_weights(list(weights.items()))
+ weights, config = quantize_model(model, config, q_group_size, q_bits)
mlx_path = Path(mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 692326fe..a7eaea52 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -22,8 +22,10 @@ MODEL_MAPPING = {
}
linear_class_predicate = (
- lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
-) # TODO remove this once we support quantization for non-multiples of 32
+ lambda m: isinstance(m, nn.Linear)
+ and m.weight.shape[0]
+ != 8 # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
+)
def _get_classes(config: dict):
@@ -142,22 +144,20 @@ def generate(
return tokens
-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load_model(model_path: Path) -> nn.Module:
"""
- Load the model from a given path or a huggingface repository.
+ Load and initialize the model from a given path.
Args:
- path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+ model_path (Path): The path to load the model from.
Returns:
- Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+ nn.Module: The loaded and initialized model.
Raises:
- FileNotFoundError: If config file or safetensors are not found.
- ValueError: If model class or args class are not found.
+ FileNotFoundError: If the weight files (.safetensors) are not found.
+ ValueError: If the model class or args class are not found or cannot be instantiated.
"""
- model_path = get_model_path(path_or_hf_repo)
-
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
@@ -165,10 +165,12 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
+
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
+
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
@@ -190,5 +192,26 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
+
+ return model
+
+
+def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+ """
+ Load the model from a given path or a huggingface repository.
+
+ Args:
+ path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+
+ Returns:
+ Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+
+ Raises:
+ FileNotFoundError: If config file or safetensors are not found.
+ ValueError: If model class or args class are not found.
+ """
+ model_path = get_model_path(path_or_hf_repo)
+
+ model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer