Mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix import warning (#479)
* fix import warning
* fix version import
* remove api, move convert to utils
* also update circle to run external PRs
parent 82f3f31d93
commit 95f82e67a2
.circleci/config.yml

@@ -1,5 +1,8 @@
 version: 2.1

+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 jobs:
   linux_build_and_test:
     docker:

@@ -31,5 +34,7 @@ workflows:
    jobs:
      - hold:
          type: approval
+     - apple/authenticate:
+         context: pr-approval
      - linux_build_and_test:
          requires: [ hold ]

llms/mlx_lm/__init__.py

@@ -1,4 +1,4 @@
-from .convert import convert
-from .utils import generate, load
+# Copyright © 2023-2024 Apple Inc.

-__version__ = "0.0.14"
+from .utils import convert, generate, load
+from .version import __version__

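After this hunk, the package root re-exports everything a downstream script needs. A minimal sketch of the resulting import surface (the repo id is a placeholder, and the load/generate call shapes are assumptions based on the mlx_lm API of the time, not something shown in this diff):

import mlx_lm
from mlx_lm import convert, generate, load  # all re-exported from mlx_lm.utils

print(mlx_lm.__version__)  # now provided by mlx_lm/version.py

# Placeholder repo id; load returns a (model, tokenizer) pair.
model, tokenizer = load("mlx-community/some-model")
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
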
llms/mlx_lm/convert.py

@@ -1,22 +1,8 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
-import copy
-import glob
-import json
-import shutil
-from pathlib import Path
-from typing import Tuple

-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_flatten

-from .utils import (
-    fetch_from_hub,
-    get_model_path,
-    linear_class_predicate,
-    save_weights,
-    upload_to_hub,
-)
+from .utils import convert


 def configure_parser() -> argparse.ArgumentParser:

@@ -59,73 +45,6 @@ def configure_parser() -> argparse.ArgumentParser:
     return parser


-def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> Tuple:
-    """
-    Applies quantization to the model weights.
-
-    Args:
-        model (nn.Module): The model to be quantized.
-        config (dict): Model configuration.
-        q_group_size (int): Group size for quantization.
-        q_bits (int): Bits per weight for quantization.
-
-    Returns:
-        Tuple: Tuple containing quantized weights and config.
-    """
-    quantized_config = copy.deepcopy(config)
-
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
-    quantized_weights = dict(tree_flatten(model.parameters()))
-
-    return quantized_weights, quantized_config
-
-
-def convert(
-    hf_path: str,
-    mlx_path: str = "mlx_model",
-    quantize: bool = False,
-    q_group_size: int = 64,
-    q_bits: int = 4,
-    dtype: str = "float16",
-    upload_repo: str = None,
-):
-    print("[INFO] Loading")
-    model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
-
-    weights = dict(tree_flatten(model.parameters()))
-    dtype = mx.float16 if quantize else getattr(mx, dtype)
-    weights = {k: v.astype(dtype) for k, v in weights.items()}
-
-    if quantize:
-        print("[INFO] Quantizing")
-        model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
-
-    if isinstance(mlx_path, str):
-        mlx_path = Path(mlx_path)
-
-    del model
-    save_weights(mlx_path, weights, donate_weights=True)
-
-    py_files = glob.glob(str(model_path / "*.py"))
-    for file in py_files:
-        shutil.copy(file, mlx_path)
-
-    tokenizer.save_pretrained(mlx_path)
-
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
-
-    if upload_repo is not None:
-        upload_to_hub(mlx_path, upload_repo, hf_path)
-
-
 if __name__ == "__main__":
     parser = configure_parser()
     args = parser.parse_args()

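Because convert.py now simply does `from .utils import convert`, the old import path keeps working and resolves to the same object as the new one. A quick check, assuming the mlx_lm package from this commit is installed:

# Both paths point at the single implementation that now lives in mlx_lm/utils.py.
from mlx_lm.utils import convert as convert_from_utils
from mlx_lm.convert import convert as convert_from_convert

assert convert_from_utils is convert_from_convert
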
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse

 import mlx.core as mx

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import glob
 import json

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import json
 import time

llms/mlx_lm/utils.py

@@ -1,9 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.

+import copy
 import gc
 import glob
 import importlib
 import json
 import logging
+import shutil
 import time
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

@@ -11,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports

@@ -515,3 +519,70 @@ def save_weights(
             f,
             indent=4,
         )
+
+
+def quantize_model(
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+) -> Tuple:
+    """
+    Applies quantization to the model weights.
+
+    Args:
+        model (nn.Module): The model to be quantized.
+        config (dict): Model configuration.
+        q_group_size (int): Group size for quantization.
+        q_bits (int): Bits per weight for quantization.
+
+    Returns:
+        Tuple: Tuple containing quantized weights and config.
+    """
+    quantized_config = copy.deepcopy(config)
+
+    nn.QuantizedLinear.quantize_module(
+        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
+    )
+    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+def convert(
+    hf_path: str,
+    mlx_path: str = "mlx_model",
+    quantize: bool = False,
+    q_group_size: int = 64,
+    q_bits: int = 4,
+    dtype: str = "float16",
+    upload_repo: str = None,
+):
+    print("[INFO] Loading")
+    model_path = get_model_path(hf_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
+
+    weights = dict(tree_flatten(model.parameters()))
+    dtype = mx.float16 if quantize else getattr(mx, dtype)
+    weights = {k: v.astype(dtype) for k, v in weights.items()}
+
+    if quantize:
+        print("[INFO] Quantizing")
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)
+
+    if isinstance(mlx_path, str):
+        mlx_path = Path(mlx_path)
+
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)
+
+    py_files = glob.glob(str(model_path / "*.py"))
+    for file in py_files:
+        shutil.copy(file, mlx_path)
+
+    tokenizer.save_pretrained(mlx_path)
+
+    with open(mlx_path / "config.json", "w") as fid:
+        json.dump(config, fid, indent=4)
+
+    if upload_repo is not None:
+        upload_to_hub(mlx_path, upload_repo, hf_path)

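With convert now defined in mlx_lm.utils (and re-exported at the package root), it can be called programmatically. A minimal sketch using only the parameters visible in the signature above; the Hugging Face repo id and output directory are placeholders:

from mlx_lm import convert

# Downloads the model, casts or quantizes the weights, and writes the MLX copy
# (weights, tokenizer, config.json, and any *.py files) to mlx_path.
convert(
    "some-org/some-hf-model",  # hypothetical hf_path
    mlx_path="mlx_model",
    quantize=True,             # 4-bit, group size 64 by default
    q_group_size=64,
    q_bits=4,
)
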
llms/mlx_lm/version.py (new file, 3 lines)

@@ -0,0 +1,3 @@
+# Copyright © 2023-2024 Apple Inc.
+
+__version__ = "0.0.14"

setup.py

@@ -1,15 +1,18 @@
+import sys
 from pathlib import Path

-import mlx_lm
 import pkg_resources
 from setuptools import setup

-with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
-    requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
+package_dir = Path(__file__).parent / "mlx_lm"
+with open(package_dir / "requirements.txt") as fid:
+    requirements = [l.strip() for l in fid.readlines()]
+
+sys.path.append(str(package_dir))
+from version import __version__

 setup(
     name="mlx-lm",
-    version=mlx_lm.__version__,
+    version=__version__,
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",

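setup.py now reads the version from mlx_lm/version.py via a sys.path insertion instead of importing the package itself, so building no longer requires mlx and the other runtime dependencies to be importable. The installed package still exposes the version in the usual way; a small sketch:

import mlx_lm

# __version__ is re-exported from mlx_lm/version.py by the package __init__.
print(mlx_lm.__version__)  # "0.0.14" at the time of this commit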