mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix import warning (#479)

* fix import warning
* fix version import
* remove api, move convert to utils
* also update circle to run external PRs
parent 82f3f31d93
commit 95f82e67a2
.circleci/config.yml
@@ -1,5 +1,8 @@
 version: 2.1
 
+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 jobs:
   linux_build_and_test:
     docker:
@@ -31,5 +34,7 @@ workflows:
     jobs:
       - hold:
           type: approval
+      - apple/authenticate:
+          context: pr-approval
       - linux_build_and_test:
          requires: [ hold ]
llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
-from .convert import convert
-from .utils import generate, load
+# Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.0.14"
+from .utils import convert, generate, load
+from .version import __version__
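With convert re-exported from mlx_lm.utils, the package-level API is unchanged for callers; only the implementing module moved. A usage sketch (the repo id and prompt are illustrative, not from this commit):

    from mlx_lm import convert, generate, load

    convert("some-org/some-model", mlx_path="mlx_model", quantize=True)  # placeholder repo id
    model, tokenizer = load("mlx_model")
    print(generate(model, tokenizer, prompt="hello"))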
llms/mlx_lm/convert.py
@@ -1,22 +1,8 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
-import copy
-import glob
-import json
-import shutil
-from pathlib import Path
-from typing import Tuple
 
-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_flatten
-
-from .utils import (
-    fetch_from_hub,
-    get_model_path,
-    linear_class_predicate,
-    save_weights,
-    upload_to_hub,
-)
+from .utils import convert
 
 
 def configure_parser() -> argparse.ArgumentParser:
@@ -59,73 +45,6 @@ def configure_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> Tuple:
-    """
-    Applies quantization to the model weights.
-
-    Args:
-        model (nn.Module): The model to be quantized.
-        config (dict): Model configuration.
-        q_group_size (int): Group size for quantization.
-        q_bits (int): Bits per weight for quantization.
-
-    Returns:
-        Tuple: Tuple containing quantized weights and config.
-    """
-    quantized_config = copy.deepcopy(config)
-
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
-    quantized_weights = dict(tree_flatten(model.parameters()))
-
-    return quantized_weights, quantized_config
-
-
-def convert(
-    hf_path: str,
-    mlx_path: str = "mlx_model",
-    quantize: bool = False,
-    q_group_size: int = 64,
-    q_bits: int = 4,
-    dtype: str = "float16",
-    upload_repo: str = None,
-):
-    print("[INFO] Loading")
-    model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
-
-    weights = dict(tree_flatten(model.parameters()))
-    dtype = mx.float16 if quantize else getattr(mx, dtype)
-    weights = {k: v.astype(dtype) for k, v in weights.items()}
-
-    if quantize:
-        print("[INFO] Quantizing")
-        model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
-
-    if isinstance(mlx_path, str):
-        mlx_path = Path(mlx_path)
-
-    del model
-    save_weights(mlx_path, weights, donate_weights=True)
-
-    py_files = glob.glob(str(model_path / "*.py"))
-    for file in py_files:
-        shutil.copy(file, mlx_path)
-
-    tokenizer.save_pretrained(mlx_path)
-
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
-
-    if upload_repo is not None:
-        upload_to_hub(mlx_path, upload_repo, hf_path)
-
-
 if __name__ == "__main__":
     parser = configure_parser()
     args = parser.parse_args()
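convert.py now keeps only the argument parser and the __main__ entry point and delegates to the convert that moved into utils.py, so both the script invocation and the library call keep working. A minimal sketch of the library call, with defaults mirroring the signature in the diff (the repo id is a placeholder):

    from mlx_lm.utils import convert

    convert(
        "some-org/some-model",   # hf_path (placeholder repo id)
        mlx_path="mlx_model_q4",
        quantize=True,           # weights are cast to float16 before quantizing
        q_group_size=64,         # defaults from the diff
        q_bits=4,
    )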
(The next three hunks add the same license header to three more files; the file headers were not captured in this mirror.)

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 
 import mlx.core as mx
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import glob
 import json
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import json
 import time
llms/mlx_lm/utils.py
@@ -1,9 +1,12 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import copy
 import gc
 import glob
 import importlib
 import json
 import logging
+import shutil
 import time
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -11,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
@@ -515,3 +519,70 @@ def save_weights(
         f,
         indent=4,
     )
+
+
+def quantize_model(
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+) -> Tuple:
+    """
+    Applies quantization to the model weights.
+
+    Args:
+        model (nn.Module): The model to be quantized.
+        config (dict): Model configuration.
+        q_group_size (int): Group size for quantization.
+        q_bits (int): Bits per weight for quantization.
+
+    Returns:
+        Tuple: Tuple containing quantized weights and config.
+    """
+    quantized_config = copy.deepcopy(config)
+
+    nn.QuantizedLinear.quantize_module(
+        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
+    )
+    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+def convert(
+    hf_path: str,
+    mlx_path: str = "mlx_model",
+    quantize: bool = False,
+    q_group_size: int = 64,
+    q_bits: int = 4,
+    dtype: str = "float16",
+    upload_repo: str = None,
+):
+    print("[INFO] Loading")
+    model_path = get_model_path(hf_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
+
+    weights = dict(tree_flatten(model.parameters()))
+    dtype = mx.float16 if quantize else getattr(mx, dtype)
+    weights = {k: v.astype(dtype) for k, v in weights.items()}
+
+    if quantize:
+        print("[INFO] Quantizing")
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)
+
+    if isinstance(mlx_path, str):
+        mlx_path = Path(mlx_path)
+
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)
+
+    py_files = glob.glob(str(model_path / "*.py"))
+    for file in py_files:
+        shutil.copy(file, mlx_path)
+
+    tokenizer.save_pretrained(mlx_path)
+
+    with open(mlx_path / "config.json", "w") as fid:
+        json.dump(config, fid, indent=4)
+
+    if upload_repo is not None:
+        upload_to_hub(mlx_path, upload_repo, hf_path)
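Because quantize_model is now part of mlx_lm.utils, it can also be driven directly on a fetched model. A sketch using only names that appear in this diff (the repo id is a placeholder, and weights may need to be materialized first, as convert() does):

    from pathlib import Path

    from mlx_lm.utils import fetch_from_hub, get_model_path, quantize_model, save_weights

    model_path = get_model_path("some-org/some-model")  # placeholder repo id
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
    weights, config = quantize_model(model, config, q_group_size=64, q_bits=4)
    save_weights(Path("mlx_model_q4"), weights)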
llms/mlx_lm/version.py (new file)
@@ -0,0 +1,3 @@
+# Copyright © 2023-2024 Apple Inc.
+
+__version__ = "0.0.14"
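version.py becomes the single source of truth for the version string: the package exposes it at runtime through the __init__.py change above, and setup.py reads it at build time below, so the number lives in exactly one place. A quick interactive check:

    >>> import mlx_lm
    >>> mlx_lm.__version__
    '0.0.14'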
llms/setup.py
@@ -1,15 +1,18 @@
+import sys
 from pathlib import Path
 
-import mlx_lm
-import pkg_resources
 from setuptools import setup
 
-with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
-    requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
+package_dir = Path(__file__).parent / "mlx_lm"
+with open(package_dir / "requirements.txt") as fid:
+    requirements = [l.strip() for l in fid.readlines()]
+
+sys.path.append(str(package_dir))
+from version import __version__
 
 setup(
     name="mlx-lm",
-    version=mlx_lm.__version__,
+    version=__version__,
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
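This is the build-time half of the fix. The old setup.py imported mlx_lm itself just to read __version__, which pulls in mlx and the model code during packaging (presumably the source of the import warning this PR fixes) and relied on the deprecated pkg_resources API. The new code appends the mlx_lm/ directory to sys.path and imports the bare version module, so only version.py executes. A standalone sketch of the pattern:

    import sys
    from pathlib import Path

    # Assumes this file sits next to the mlx_lm/ package directory.
    package_dir = Path(__file__).parent / "mlx_lm"
    sys.path.append(str(package_dir))
    from version import __version__  # executes mlx_lm/version.py only

    print(__version__)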