Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
chore: fix the convert.py script for weights that are not sanitized and support quant for non-32 dimensions (#340)

* chore: fix convert script for weights not sanitized and support quant for non-32 dims

* Update llms/mlx_lm/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: fix typo

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
parent 61297f547b
commit 527cea4027
llms/mlx_lm/convert.py

@@ -10,7 +10,7 @@ import mlx.nn as nn
 import transformers
 from mlx.utils import tree_flatten

-from .utils import get_model_path, linear_class_predicate, load
+from .utils import get_model_path, linear_class_predicate, load_model

 MAX_FILE_SIZE_GB = 15

@@ -60,30 +60,23 @@ def fetch_from_hub(
 ) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
     model_path = get_model_path(model_path)

-    weight_files = glob.glob(f"{model_path}/*.safetensors")
-    if not weight_files:
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf).items())
+    model = load_model(model_path)

     config = transformers.AutoConfig.from_pretrained(model_path)
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

-    return weights, config.to_dict(), tokenizer
+    return model, config.to_dict(), tokenizer


 def quantize_model(
-    weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
 ) -> tuple:
     """
     Applies quantization to the model weights.

     Args:
-        weights (dict): Model weights.
+        model (nn.Module): The model to be quantized.
         config (dict): Model configuration.
-        hf_path (str): HF model path.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.

@@ -91,8 +84,6 @@ def quantize_model(
         tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    model, _ = load(hf_path)
-    model.load_weights(list(weights.items()))

     nn.QuantizedLinear.quantize_module(
         model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
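To make the reworked flow concrete, here is a minimal sketch (not the commit's code) of what quantize_model now does with an already-loaded model; the predicate is copied from the utils.py hunk below, everything else is illustrative:

import copy
import mlx.nn as nn
from mlx.utils import tree_flatten

# Same predicate that llms/mlx_lm/utils.py introduces below: skip 8-output gate layers.
linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)

def quantize_model_sketch(model: nn.Module, config: dict, q_group_size: int, q_bits: int) -> tuple:
    # Work on a copy of the config so the original stays untouched.
    quantized_config = copy.deepcopy(config)
    # Quantize eligible nn.Linear layers in place; the predicate decides which ones.
    nn.QuantizedLinear.quantize_module(
        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
    )
    # Flatten the (now quantized) parameters into a plain dict, as convert() does for saving.
    quantized_weights = dict(tree_flatten(model.parameters()))
    return quantized_weights, quantized_config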
@@ -183,12 +174,16 @@ def convert(
     upload_repo: str = None,
 ):
     print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(hf_path)
+    model, config, tokenizer = fetch_from_hub(hf_path)

+    weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}

     if quantize:
         print("[INFO] Quantizing")
-        weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)

     mlx_path = Path(mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)
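For context, this is roughly how the updated convert() would be driven from Python after this change; the import path, repository id, output path, and quantization settings below are illustrative assumptions, not taken from the commit:

from mlx_lm.convert import convert  # assumes the package installed from llms/mlx_lm

convert(
    hf_path="some-org/some-model",  # hypothetical Hugging Face repo id or local path
    mlx_path="mlx_model",           # where the converted weights are written
    quantize=True,                  # runs quantize_model() on the loaded model
    q_group_size=64,                # illustrative quantization settings
    q_bits=4,
)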
llms/mlx_lm/utils.py

@@ -22,8 +22,10 @@ MODEL_MAPPING = {
 }

 linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
-)  # TODO remove this once we support quantization for non-multiples of 32
+    lambda m: isinstance(m, nn.Linear)
+    and m.weight.shape[0]
+    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
+)


 def _get_classes(config: dict):
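A quick illustration of what the relaxed predicate changes: layers no longer need output dimensions divisible by 32, and only nn.Linear layers with exactly 8 output features (the Mixtral-style expert gates) are skipped. The layer sizes here are made up for the example:

import mlx.nn as nn

linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)

gate = nn.Linear(4096, 8)      # router/gate over 8 experts -> skipped
proj = nn.Linear(4096, 11008)  # ordinary projection        -> quantized
odd = nn.Linear(4096, 1000)    # not a multiple of 32        -> now quantized too

print(linear_class_predicate(gate))  # False
print(linear_class_predicate(proj))  # True
print(linear_class_predicate(odd))   # True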
@@ -142,22 +144,20 @@ def generate(
     return tokens


-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load_model(model_path: Path) -> nn.Module:
     """
-    Load the model from a given path or a huggingface repository.
+    Load and initialize the model from a given path.

     Args:
-        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+        model_path (Path): The path to load the model from.

     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+        nn.Module: The loaded and initialized model.

     Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
+        FileNotFoundError: If the weight files (.safetensors) are not found.
+        ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    model_path = get_model_path(path_or_hf_repo)
-
     try:
         with open(model_path / "config.json", "r") as f:
             config = json.load(f)
@@ -165,10 +165,12 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     except FileNotFoundError:
         logging.error(f"Config file not found in {model_path}")
         raise

     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")

     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
@@ -190,5 +192,26 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     model.load_weights(list(weights.items()))

     mx.eval(model.parameters())
+
+    return model
+
+
+def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+    """
+    Load the model from a given path or a huggingface repository.
+
+    Args:
+        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+
+    Returns:
+        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+
+    Raises:
+        FileNotFoundError: If config file or safetensors are not found.
+        ValueError: If model class or args class are not found.
+    """
+    model_path = get_model_path(path_or_hf_repo)
+
+    model = load_model(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     return model, tokenizer
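With this split in place, callers that only need an initialized model (such as convert.py above) can use load_model directly, while load keeps the old model-plus-tokenizer behavior. A hedged usage sketch, with a made-up repository id and assuming the installed package layout:

from mlx_lm.utils import get_model_path, load, load_model

# Old-style entry point: resolves the repo id or local path, returns model and tokenizer.
model, tokenizer = load("some-org/some-model")

# New helper: only load and initialize the model from an already-resolved path.
model_only = load_model(get_model_path("some-org/some-model"))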