mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
chore: fix the convert.py script for weights that are not sanitized, and support quantization for non-32 dimensions (#340)
* chore: fix convert script for weights not sanitized and support quant for non-32 dims
* Update llms/mlx_lm/utils.py
* chore: fix typo

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
```diff
@@ -22,8 +22,10 @@ MODEL_MAPPING = {
 }
 
 linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
-)  # TODO remove this once we support quantization for non-multiples of 32
+    lambda m: isinstance(m, nn.Linear)
+    and m.weight.shape[0]
+    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
+)
 
 
 def _get_classes(config: dict):
```
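The new predicate drops the old `% 32 == 0` restriction (quantization now supports non-multiples of 32) and instead skips only layers whose output dimension is 8, i.e. Mixtral's expert router, so already-published quantized models don't have to be re-quantized and re-uploaded. A minimal sketch of how the predicate classifies layers, assuming `mlx` is installed; the layer sizes below are illustrative Mixtral-style dimensions, not taken from this diff:

```python
import mlx.nn as nn

# Same shape as the predicate in the hunk above: quantize every nn.Linear
# except ones whose output dimension is 8 (the Mixtral gate/router width).
linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)

proj = nn.Linear(4096, 14336)  # ordinary projection   -> quantize
gate = nn.Linear(4096, 8)      # MoE router, 8 experts -> leave in full precision

print(linear_class_predicate(proj))  # True
print(linear_class_predicate(gate))  # False
```

The trade-off, implied by the in-code comment: any non-gate layer that happens to have output dimension 8 is skipped too, which the commit accepts to avoid re-uploading the existing Mixtral quants.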
```diff
@@ -142,22 +144,20 @@ def generate(
     return tokens
 
 
-def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+def load_model(model_path: Path) -> nn.Module:
     """
-    Load the model from a given path or a huggingface repository.
+    Load and initialize the model from a given path.
 
     Args:
-        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+        model_path (Path): The path to load the model from.
 
     Returns:
-        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+        nn.Module: The loaded and initialized model.
 
     Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
+        FileNotFoundError: If the weight files (.safetensors) are not found.
+        ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    model_path = get_model_path(path_or_hf_repo)
-
     try:
         with open(model_path / "config.json", "r") as f:
             config = json.load(f)
```
```diff
@@ -165,10 +165,12 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     except FileNotFoundError:
         logging.error(f"Config file not found in {model_path}")
         raise
+
     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
+
     weights = {}
     for wf in weight_files:
         weights.update(mx.load(wf))
```
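Stripped of the surrounding function, the loading pattern that `load_model` wraps is: read `config.json`, glob the `*.safetensors` shards, and merge them into one flat dict. A standalone sketch of that pattern, assuming `mlx` is installed and a local checkpoint directory exists; the directory name `mlx_model` is hypothetical:

```python
import glob
import json
from pathlib import Path

import mlx.core as mx

model_path = Path("mlx_model")  # hypothetical local checkpoint directory

# config.json carries model_type and the args used to pick the model class
with open(model_path / "config.json", "r") as f:
    config = json.load(f)

# merge every *.safetensors shard into a single name -> array mapping
weights = {}
for wf in glob.glob(str(model_path / "*.safetensors")):
    weights.update(mx.load(wf))

print(config.get("model_type"), len(weights), "tensors")
```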
```diff
@@ -190,5 +192,26 @@ def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
     model.load_weights(list(weights.items()))
 
     mx.eval(model.parameters())
 
     return model
+
+
+def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
+    """
+    Load the model from a given path or a huggingface repository.
+
+    Args:
+        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
+
+    Returns:
+        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.
+
+    Raises:
+        FileNotFoundError: If config file or safetensors are not found.
+        ValueError: If model class or args class are not found.
+    """
+    model_path = get_model_path(path_or_hf_repo)
+
+    model = load_model(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    return model, tokenizer
```
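After the refactor, `load` remains the one-call entry point (model plus tokenizer), while `load_model` is reusable on its own when no tokenizer is needed. A hypothetical usage sketch; the import path follows this file's location (`llms/mlx_lm/utils.py`) and the repo id is illustrative:

```python
from pathlib import Path

from mlx_lm.utils import load, load_model

# one call: resolves a local path or downloads a Hugging Face repo,
# then returns (model, tokenizer)
model, tokenizer = load("mlx-community/Mistral-7B-v0.1-4bit")  # hypothetical repo id

# model only: reuse an already-resolved local checkpoint directory
model = load_model(Path("mlx_model"))  # hypothetical local directory
```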