Fix import warning (#479)

* fix import warning
* fix version import
* remove api, move convert to utils
* also update circle to run external PRs
Awni Hannun 2024-02-27 08:47:56 -08:00 committed by GitHub
parent 82f3f31d93
commit 95f82e67a2
9 changed files with 99 additions and 92 deletions
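The heart of the fix: the package version moves into a standalone llms/mlx_lm/version.py, so setup.py no longer imports mlx_lm (whose __init__.py pulls in mlx and transformers via utils) just to read __version__. A condensed sketch of the new arrangement, stitched together from the diffs below for illustration:

# mlx_lm/version.py -- nothing but the version string, safe to import anywhere
__version__ = "0.0.14"

# setup.py (condensed) -- import version.py directly instead of the package
import sys
from pathlib import Path

package_dir = Path(__file__).parent / "mlx_lm"
sys.path.append(str(package_dir))
from version import __version__  # executes version.py alone, no mlx import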

View File

@@ -1,5 +1,8 @@
version: 2.1
orbs:
  apple: ml-explore/pr-approval@0.1.0
jobs:
  linux_build_and_test:
    docker:
@@ -31,5 +34,7 @@ workflows:
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - linux_build_and_test:
          requires: [ hold ]

View File

@@ -1,4 +1,4 @@
from .convert import convert
from .utils import generate, load
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.0.14"
from .utils import convert, generate, load
from .version import __version__
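With the convert module folded into utils, all three public entry points are re-exported from the package root. A short usage sketch (the repo id, prompt, and token count are placeholders):

from mlx_lm import convert, generate, load

# Convert a hypothetical Hugging Face checkpoint to MLX format.
convert("some-org/some-model", mlx_path="mlx_model", quantize=True)

# Load the converted weights and run generation.
model, tokenizer = load("mlx_model")
text = generate(model, tokenizer, prompt="hello", max_tokens=32)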

View File

@@ -1,22 +1,8 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from .utils import (
    fetch_from_hub,
    get_model_path,
    linear_class_predicate,
    save_weights,
    upload_to_hub,
)
from .utils import convert
def configure_parser() -> argparse.ArgumentParser:
@@ -59,73 +45,6 @@ def configure_parser() -> argparse.ArgumentParser:
    return parser
def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.QuantizedLinear.quantize_module(
        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
    )
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    quantized_weights = dict(tree_flatten(model.parameters()))
    return quantized_weights, quantized_config


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: str = None,
):
    print("[INFO] Loading")
    model_path = get_model_path(hf_path)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()

View File

@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx

View File

@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import glob
import json

View File

@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import time

View File

@@ -1,9 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import copy
import gc
import glob
import importlib
import json
import logging
import shutil
import time
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -11,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
# Local imports
@@ -515,3 +519,70 @@ def save_weights(
            f,
            indent=4,
        )
def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.QuantizedLinear.quantize_module(
        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
    )
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    quantized_weights = dict(tree_flatten(model.parameters()))
    return quantized_weights, quantized_config


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: str = None,
):
    print("[INFO] Loading")
    model_path = get_model_path(hf_path)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
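For reference, a hypothetical call into the relocated convert, exercising the quantize_model path above (repo id and output path are placeholders):

from mlx_lm.utils import convert

# quantize=True routes through quantize_model() with the defaults from the
# signature above: 4-bit weights in groups of 64.
convert(
    hf_path="some-org/some-model",
    mlx_path="mlx_model_q4",
    quantize=True,
)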

llms/mlx_lm/version.py Normal file
View File

@@ -0,0 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.0.14"
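At runtime the package re-exports the same constant through __init__.py, so existing callers are unaffected; for example:

import mlx_lm

print(mlx_lm.__version__)  # "0.0.14", via `from .version import __version__`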

View File

@@ -1,15 +1,18 @@
import sys
from pathlib import Path

import mlx_lm
import pkg_resources
from setuptools import setup

with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
    requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]

package_dir = Path(__file__).parent / "mlx_lm"

with open(package_dir / "requirements.txt") as fid:
    requirements = [l.strip() for l in fid.readlines()]

sys.path.append(str(package_dir))
from version import __version__

setup(
    name="mlx-lm",
    version=mlx_lm.__version__,
    version=__version__,
    description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",