feat: move lora into mlx-lm (#337)
* feat: Add lora and qlora training to mlx-lm

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -1,13 +1,14 @@
+import copy
 import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator, Tuple
+from typing import Any, Dict, Generator, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .models import llama, mixtral, phi2, plamo, qwen
@@ -21,6 +22,7 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }
+MAX_FILE_SIZE_GB = 15

 linear_class_predicate = (
     lambda m: isinstance(m, nn.Linear)
@@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module:

     mx.eval(model.parameters())

+    model.eval()
     return model

@@ -211,7 +214,7 @@ def load(
     path_or_hf_repo: str, tokenizer_config={}
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
-    Load the model from a given path or a huggingface repository.
+    Load the model and tokenizer from a given path or a huggingface repository.

     Args:
         model_path (Path): The path or the huggingface repository to load the model from.
@@ -229,3 +232,103 @@ def load(
     model = load_model(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
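
For orientation, a minimal sketch of calling the `load` API shown above; the repo id and tokenizer option below are placeholder assumptions, not part of the commit:

```python
from mlx_lm.utils import load

# tokenizer_config is forwarded verbatim to AutoTokenizer.from_pretrained,
# so any HF tokenizer kwarg can be threaded through here.
model, tokenizer = load(
    "mlx-community/example-model",  # hypothetical HF repo or local path
    tokenizer_config={"trust_remote_code": True},
)
```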
+
+
+def fetch_from_hub(
+    model_path: Path,
+) -> Tuple[Dict, dict, PreTrainedTokenizer]:
+    model = load_model(model_path)
+
+    config = AutoConfig.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    return model, config.to_dict(), tokenizer
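
Note that the annotated return type in the diff lists `Dict` first, though the first element returned is the loaded module. A sketch of how `fetch_from_hub` might be driven, assuming the weights already sit in a local directory (the path is a placeholder):

```python
from pathlib import Path

from mlx_lm.utils import fetch_from_hub

# Returns the loaded MLX module, the HF config as a plain dict,
# and the matching tokenizer in one call.
model, config, tokenizer = fetch_from_hub(Path("models/example-model"))
print(config["model_type"])
```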
+
+
+def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+    """
+    Splits the weights into smaller shards.
+
+    Args:
+        weights (dict): Model weights.
+        max_file_size_gb (int): Maximum size of each shard in gigabytes.
+
+    Returns:
+        list: List of weight shards.
+    """
+    max_file_size_bytes = max_file_size_gb << 30
+    shards = []
+    shard, shard_size = {}, 0
+    for k, v in weights.items():
+        estimated_size = v.size * v.dtype.size
+        if shard_size + estimated_size > max_file_size_bytes:
+            shards.append(shard)
+            shard, shard_size = {}, 0
+        shard[k] = v
+        shard_size += estimated_size
+    shards.append(shard)
+    return shards
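
The sharding is a greedy first-fit pass: tensors are packed in iteration order until the next one would push the current shard past the byte cap, at which point a fresh shard is started. A runnable sketch with stand-in tensors; the fake classes are only there because `make_shards` touches nothing but `.size` and `.dtype.size`:

```python
from dataclasses import dataclass

from mlx_lm.utils import make_shards

@dataclass
class FakeDType:
    size: int  # bytes per element

@dataclass
class FakeTensor:
    size: int  # element count
    dtype: FakeDType

one_gb = 1 << 30
w = FakeTensor(size=one_gb, dtype=FakeDType(size=1))  # exactly 1 GB each
shards = make_shards({"a": w, "b": w, "c": w}, max_file_size_gb=1)
print([list(s) for s in shards])  # [['a'], ['b'], ['c']]
```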
+
+
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+    """
+    Uploads the model to Hugging Face hub.
+
+    Args:
+        path (str): Local path to the model.
+        upload_repo (str): Name of the HF repo to upload to.
+        hf_path (str): Path to the original Hugging Face model.
+    """
+    import os
+
+    from huggingface_hub import HfApi, ModelCard, logging
+
+    card = ModelCard.load(hf_path)
+    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+    card.text = f"""
+# {upload_repo}
+This model was converted to MLX format from [`{hf_path}`]().
+Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
+## Use with mlx
+
+```bash
+pip install mlx-lm
+```
+
+```python
+from mlx_lm import load, generate
+
+model, tokenizer = load("{upload_repo}")
+response = generate(model, tokenizer, prompt="hello", verbose=True)
+```
+"""
+    card.save(os.path.join(path, "README.md"))
+
+    logging.set_verbosity_info()
+
+    api = HfApi()
+    api.create_repo(repo_id=upload_repo, exist_ok=True)
+    api.upload_folder(
+        folder_path=path,
+        repo_id=upload_repo,
+        repo_type="model",
+    )
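
To exercise the new uploader, something like the following would do it, assuming a prior `huggingface-cli login`; all repo names below are placeholders:

```python
from mlx_lm.utils import upload_to_hub

upload_to_hub(
    path="mlx_model",                          # local dir with converted weights
    upload_repo="mlx-community/example-4bit",  # hypothetical destination repo
    hf_path="example-org/example-model",       # hypothetical source model on HF
)
```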
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+    """Save model weights into specified directory."""
+    if isinstance(save_path, str):
+        save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    shards = make_shards(weights)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
+    for i, shard in enumerate(shards):
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        mx.save_safetensors(str(save_path / shard_name), shard)
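
`save_weights` uses the HF-style sharded file name pattern only when `make_shards` returns more than one shard; otherwise everything lands in a single `model.safetensors`. A minimal smoke test:

```python
import mlx.core as mx

from mlx_lm.utils import save_weights

# One tiny tensor -> a single unsharded model.safetensors in ./mlx_model
save_weights("mlx_model", {"w": mx.zeros((8, 8))})
```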