feat: move lora into mlx-lm (#337)

* feat: Add lora and qlora training to mlx-lm


---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Anchen
2024-01-23 08:44:37 -08:00
committed by GitHub
parent 85c1ff8fd6
commit 362e88a744
13 changed files with 987 additions and 111 deletions

View File

@@ -2,17 +2,21 @@ import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
from typing import Dict, Tuple
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
import transformers
from mlx.utils import tree_flatten
from .utils import get_model_path, linear_class_predicate, load_model
MAX_FILE_SIZE_GB = 15
from .utils import (
fetch_from_hub,
get_model_path,
linear_class_predicate,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
@@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser:
return parser
def fetch_from_hub(
model_path: str,
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
model_path = get_model_path(model_path)
model = load_model(model_path)
config = transformers.AutoConfig.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
return model, config.to_dict(), tokenizer
def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> tuple:
) -> Tuple:
"""
Applies quantization to the model weights.
@@ -81,7 +72,7 @@ def quantize_model(
q_bits (int): Bits per weight for quantization.
Returns:
tuple: Tuple containing quantized weights and config.
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
@@ -94,76 +85,6 @@ def quantize_model(
return quantized_weights, quantized_config
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
@@ -174,7 +95,8 @@ def convert(
upload_repo: str = None,
):
print("[INFO] Loading")
model, config, tokenizer = fetch_from_hub(hf_path)
model_path = get_model_path(hf_path)
model, config, tokenizer = fetch_from_hub(model_path)
weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -185,12 +107,17 @@ def convert(
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)
mlx_path = Path(mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
if isinstance(mlx_path, str):
mlx_path = Path(mlx_path)
save_weights(mlx_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)