diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 91a0aa42..ead4f0e4 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -10,7 +10,7 @@ import mlx.nn as nn
 import transformers
 from mlx.utils import tree_flatten
 
-from .utils import get_model_path, linear_class_predicate, load
+from .utils import get_model_path, linear_class_predicate, load_model
 
 MAX_FILE_SIZE_GB = 15
 
@@ -60,30 +60,23 @@ def fetch_from_hub(
 ) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
     model_path = get_model_path(model_path)
 
-    weight_files = glob.glob(f"{model_path}/*.safetensors")
-    if not weight_files:
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf).items())
+    model = load_model(model_path)
 
     config = transformers.AutoConfig.from_pretrained(model_path)
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-    return weights, config.to_dict(), tokenizer
+    return model, config.to_dict(), tokenizer
 
 
 def quantize_model(
-    weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
 ) -> tuple:
     """
     Applies quantization to the model weights.
 
     Args:
-        weights (dict): Model weights.
+        model (nn.Module): The model to be quantized.
         config (dict): Model configuration.
-        hf_path (str): HF model path..
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
 
@@ -91,8 +84,6 @@ def quantize_model(
         tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    model, _ = load(hf_path)
-    model.load_weights(list(weights.items()))
 
     nn.QuantizedLinear.quantize_module(
         model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
@@ -183,12 +174,16 @@ def convert(
     upload_repo: str = None,
 ):
     print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(hf_path)
+    model, config, tokenizer = fetch_from_hub(hf_path)
+
+    weights = dict(tree_flatten(model.parameters()))
 
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
+
     if quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)
 
     mlx_path = Path(mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 692326fe..a7eaea52 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -22,8 +22,10 @@ MODEL_MAPPING = {
 }
 
 linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
-)  # TODO remove this once we support quantization for non-multiples of 32
+    lambda m: isinstance(m, nn.Linear)
+    and m.weight.shape[0]
+    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
+)
 
 
 def _get_classes(config: dict):
@@ -142,22 +144,20 @@ def generate(
     return tokens
 
 
-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load_model(model_path: Path) -> nn.Module:
     """
-    Load the model from a given path or a huggingface repository.
+    Load and initialize the model from a given path.
 
     Args:
-        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+        model_path (Path): The path to load the model from.
 
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+        nn.Module: The loaded and initialized model.
 
     Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
+        FileNotFoundError: If the weight files (.safetensors) are not found.
+        ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    model_path = get_model_path(path_or_hf_repo)
-
     try:
         with open(model_path / "config.json", "r") as f:
             config = json.load(f)
-
@@ -165,10 +165,12 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     except FileNotFoundError:
         logging.error(f"Config file not found in {model_path}")
         raise
+
     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
+
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
@@ -190,5 +192,26 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     model.load_weights(list(weights.items()))
 
     mx.eval(model.parameters())
+
+    return model
+
+
+def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+    """
+    Load the model from a given path or a huggingface repository.
+
+    Args:
+        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+
+    Returns:
+        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+
+    Raises:
+        FileNotFoundError: If config file or safetensors are not found.
+        ValueError: If model class or args class are not found.
+    """
+    model_path = get_model_path(path_or_hf_repo)
+
+    model = load_model(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     return model, tokenizer
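For context, a minimal sketch of how the refactored pieces fit together after this change, mirroring the new `convert()` body in the diff above. The Hugging Face repo id and the quantization settings are example values only, and the imports assume the `mlx_lm` package layout shown in the patch:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

from mlx_lm.convert import fetch_from_hub, quantize_model
from mlx_lm.utils import load

# fetch_from_hub now returns an initialized nn.Module instead of a raw weights dict.
model, config, tokenizer = fetch_from_hub("mistralai/Mistral-7B-v0.1")

# Cast the flattened parameters, push them back into the model, then quantize the
# model object directly; quantize_model no longer re-loads anything from an hf_path.
weights = {k: v.astype(mx.float16) for k, v in tree_flatten(model.parameters())}
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size=64, q_bits=4)

# load keeps its public signature: it resolves the path or repo id, delegates to
# the new load_model helper, and still returns (model, tokenizer).
model, tokenizer = load("mistralai/Mistral-7B-v0.1")
```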