Fix import warning (#479)

* fix import warning
* fix version import
* remove api, move convert to utils
* also update circle to run external PRs
Awni Hannun 2024-02-27 08:47:56 -08:00 committed by GitHub
parent 82f3f31d93
commit 95f82e67a2
9 changed files with 99 additions and 92 deletions

@@ -1,5 +1,8 @@
 version: 2.1

+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 jobs:
   linux_build_and_test:
     docker:
@@ -31,5 +34,7 @@ workflows:
     jobs:
      - hold:
          type: approval
+     - apple/authenticate:
+         context: pr-approval
      - linux_build_and_test:
          requires: [ hold ]

@@ -1,4 +1,4 @@
-from .convert import convert
-from .utils import generate, load
+# Copyright © 2023-2024 Apple Inc.

-__version__ = "0.0.14"
+from .utils import convert, generate, load
+from .version import __version__
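
With convert now re-exported alongside generate and load, and __version__ sourced from the new version module, the package-level surface can be used as in the sketch below. This is illustrative and not part of the commit; the model path and the load/generate call signatures are assumptions based on mlx-lm's documented API:

    # Illustrative only -- not part of this diff.
    # load/generate signatures assumed from mlx-lm's README; "mlx_model" is the
    # default output directory of convert(), used here as a placeholder path.
    from mlx_lm import __version__, convert, generate, load

    print(__version__)  # "0.0.14", now defined in version.py

    model, tokenizer = load("mlx_model")
    text = generate(model, tokenizer, prompt="Hello", max_tokens=64)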

@@ -1,22 +1,8 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
-import copy
-import glob
-import json
-import shutil
-from pathlib import Path
-from typing import Tuple
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_flatten
-
-from .utils import (
-    fetch_from_hub,
-    get_model_path,
-    linear_class_predicate,
-    save_weights,
-    upload_to_hub,
-)
+
+from .utils import convert


def configure_parser() -> argparse.ArgumentParser:
@@ -59,73 +45,6 @@ def configure_parser() -> argparse.ArgumentParser:
    return parser

-
-def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> Tuple:
-    """
-    Applies quantization to the model weights.
-
-    Args:
-        model (nn.Module): The model to be quantized.
-        config (dict): Model configuration.
-        q_group_size (int): Group size for quantization.
-        q_bits (int): Bits per weight for quantization.
-
-    Returns:
-        Tuple: Tuple containing quantized weights and config.
-    """
-    quantized_config = copy.deepcopy(config)
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
-    quantized_weights = dict(tree_flatten(model.parameters()))
-
-    return quantized_weights, quantized_config
-
-
-def convert(
-    hf_path: str,
-    mlx_path: str = "mlx_model",
-    quantize: bool = False,
-    q_group_size: int = 64,
-    q_bits: int = 4,
-    dtype: str = "float16",
-    upload_repo: str = None,
-):
-    print("[INFO] Loading")
-    model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
-
-    weights = dict(tree_flatten(model.parameters()))
-    dtype = mx.float16 if quantize else getattr(mx, dtype)
-    weights = {k: v.astype(dtype) for k, v in weights.items()}
-
-    if quantize:
-        print("[INFO] Quantizing")
-        model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
-
-    if isinstance(mlx_path, str):
-        mlx_path = Path(mlx_path)
-
-    del model
-    save_weights(mlx_path, weights, donate_weights=True)
-
-    py_files = glob.glob(str(model_path / "*.py"))
-    for file in py_files:
-        shutil.copy(file, mlx_path)
-
-    tokenizer.save_pretrained(mlx_path)
-
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
-
-    if upload_repo is not None:
-        upload_to_hub(mlx_path, upload_repo, hf_path)
-
-
if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse

 import mlx.core as mx

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import glob
 import json

@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import argparse
 import json
 import time

@@ -1,9 +1,12 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import copy
 import gc
 import glob
 import importlib
 import json
 import logging
+import shutil
 import time
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -11,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports
@@ -515,3 +519,70 @@ def save_weights(
                f,
                indent=4,
            )
+
+
+def quantize_model(
+    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+) -> Tuple:
+    """
+    Applies quantization to the model weights.
+
+    Args:
+        model (nn.Module): The model to be quantized.
+        config (dict): Model configuration.
+        q_group_size (int): Group size for quantization.
+        q_bits (int): Bits per weight for quantization.
+
+    Returns:
+        Tuple: Tuple containing quantized weights and config.
+    """
+    quantized_config = copy.deepcopy(config)
+    nn.QuantizedLinear.quantize_module(
+        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
+    )
+    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_weights = dict(tree_flatten(model.parameters()))
+
+    return quantized_weights, quantized_config
+
+
+def convert(
+    hf_path: str,
+    mlx_path: str = "mlx_model",
+    quantize: bool = False,
+    q_group_size: int = 64,
+    q_bits: int = 4,
+    dtype: str = "float16",
+    upload_repo: str = None,
+):
+    print("[INFO] Loading")
+    model_path = get_model_path(hf_path)
+    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
+
+    weights = dict(tree_flatten(model.parameters()))
+    dtype = mx.float16 if quantize else getattr(mx, dtype)
+    weights = {k: v.astype(dtype) for k, v in weights.items()}
+
+    if quantize:
+        print("[INFO] Quantizing")
+        model.load_weights(list(weights.items()))
+        weights, config = quantize_model(model, config, q_group_size, q_bits)
+
+    if isinstance(mlx_path, str):
+        mlx_path = Path(mlx_path)
+
+    del model
+    save_weights(mlx_path, weights, donate_weights=True)
+
+    py_files = glob.glob(str(model_path / "*.py"))
+    for file in py_files:
+        shutil.copy(file, mlx_path)
+
+    tokenizer.save_pretrained(mlx_path)
+
+    with open(mlx_path / "config.json", "w") as fid:
+        json.dump(config, fid, indent=4)
+
+    if upload_repo is not None:
+        upload_to_hub(mlx_path, upload_repo, hf_path)
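
The convert signature is unchanged by the move, so it can now be reached through the package namespace. A minimal sketch of a quantized conversion follows; the Hugging Face repo id is a placeholder assumption, everything else mirrors the defaults shown above:

    # Illustrative only -- the repo id below is a placeholder.
    from mlx_lm import convert

    convert(
        hf_path="<hf-org>/<model>",  # Hugging Face repo to download and convert
        mlx_path="mlx_model",        # output directory (the default)
        quantize=True,               # routes through quantize_model() above
        q_group_size=64,
        q_bits=4,
    )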

llms/mlx_lm/version.py (new file)

@@ -0,0 +1,3 @@
+# Copyright © 2023-2024 Apple Inc.
+
+__version__ = "0.0.14"

@@ -1,15 +1,18 @@
 import sys
 from pathlib import Path

-import mlx_lm
-import pkg_resources
 from setuptools import setup

-with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
-    requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
+package_dir = Path(__file__).parent / "mlx_lm"
+
+with open(package_dir / "requirements.txt") as fid:
+    requirements = [l.strip() for l in fid.readlines()]
+
+sys.path.append(str(package_dir))
+
+from version import __version__

 setup(
     name="mlx-lm",
-    version=mlx_lm.__version__,
+    version=__version__,
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
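
A plausible reading of why this setup.py change removes the import warning (an inference from the diff, not text from the commit): `import mlx_lm` at build time executes the package `__init__.py` and, through it, heavy runtime dependencies that may not be installed yet, while the new code only evaluates the one-line version module. The sketch below isolates that pattern; the directory name is the only assumption:

    # Inference from the diff, not part of the commit.
    # Old: `import mlx_lm` runs the package __init__.py, which imports mlx,
    # transformers, etc. -- the likely source of the install-time warning.
    # New: put the package directory on sys.path and import the standalone
    # version module, which only defines a string.
    import sys
    from pathlib import Path

    package_dir = Path(__file__).parent / "mlx_lm"
    sys.path.append(str(package_dir))

    from version import __version__  # executes only: __version__ = "0.0.14"

    print(__version__)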