Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Support for OpenAI’s fine-tuning dataset format (#548)
* LoRA: move load_dataset to the tuner/datasets.py file
* LoRA: support OpenAI chat format datasets (see https://platform.openai.com/docs/guides/fine-tuning/example-format; illustrative examples below)
* LoRA: support OpenAI completion format datasets
* LoRA: defer dataset formatting to reduce the memory footprint
* Refactor dataset item access in PromptCompletionDataset
* Update mlx_lm/LORA.md
* Check for unsupported data formats
* Add tests and fine-tuning docs
* Add jinja2 for the chat template
* Nits in readme

Co-authored-by: Awni Hannun <awni@apple.com>
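For reference, the three jsonl record shapes involved are roughly as follows. These are illustrative examples written against the OpenAI docs linked above and the pre-existing plain-text format, not lines taken from the commit.

Chat format (one conversation per line):
{"messages": [{"role": "user", "content": "What is MLX?"}, {"role": "assistant", "content": "An array framework for Apple silicon."}]}

Completion format:
{"prompt": "What is MLX?", "completion": "An array framework for Apple silicon."}

Plain-text format (already supported):
{"text": "MLX is an array framework for Apple silicon."}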
@@ -12,6 +12,7 @@ import numpy as np
 import yaml
 from mlx.utils import tree_flatten
 
+from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load
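The added import pulls load_dataset from the new tuner/datasets.py module. As a minimal sketch of the idea, assuming a Hugging Face style tokenizer that exposes apply_chat_template (the reason jinja2 becomes a dependency); the class and argument names here are illustrative, not the actual contents of tuner/datasets.py:

import json
from pathlib import Path


class ChatDataset:
    """Light-weight wrapper for a jsonl file where each line is {"messages": [...]}."""

    def __init__(self, path: Path, tokenizer):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        # Render the conversation into a single training string only when the
        # item is accessed, so the fully formatted corpus never sits in memory.
        messages = self._data[idx]["messages"]
        return self._tokenizer.apply_chat_template(messages, tokenize=False)

    def __len__(self):
        return len(self._data)

Formatting each record in __getitem__ keeps memory usage proportional to the raw jsonl rather than the fully rendered prompts, which matches the "reduce the memory footprint" bullet above.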
@@ -141,46 +142,6 @@ def build_parser():
     return parser
 
 
-class Dataset:
-    """
-    Light-weight wrapper to hold lines from a jsonl file
-    """
-
-    def __init__(self, path: Path, key: str = "text"):
-        if not path.exists():
-            self._data = None
-        else:
-            with open(path, "r") as fid:
-                self._data = [json.loads(l) for l in fid]
-        self._key = key
-
-    def __getitem__(self, idx: int):
-        return self._data[idx][self._key]
-
-    def __len__(self):
-        if self._data is None:
-            return 0
-        return len(self._data)
-
-
-def load_dataset(args):
-    names = ("train", "valid", "test")
-    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
-    if args.train and len(train) == 0:
-        raise ValueError(
-            "Training set not found or empty. Must provide training set for fine-tuning."
-        )
-    if args.train and len(valid) == 0:
-        raise ValueError(
-            "Validation set not found or empty. Must provide validation set for fine-tuning."
-        )
-    if args.test and len(test) == 0:
-        raise ValueError(
-            "Test set not found or empty. Must provide test set for evaluation."
-        )
-    return train, valid, test
-
-
 def print_trainable_parameters(model):
     total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     trainable_p = (
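The removed code above only understood the {"text": ...} shape. The commit message notes that unsupported formats are now rejected explicitly; a minimal, self-contained sketch of that kind of dispatch (illustrative only, not the function added in tuner/datasets.py):

import json
from pathlib import Path


def detect_format(path: Path) -> str:
    """Peek at the first record of a jsonl file and name its format."""
    with open(path, "r") as fid:
        sample = json.loads(fid.readline())
    if "messages" in sample:
        return "chat"         # OpenAI chat fine-tuning format
    if "prompt" in sample and "completion" in sample:
        return "completions"  # OpenAI prompt/completion format
    if "text" in sample:
        return "text"         # original plain-text format
    raise ValueError("Unsupported data format.")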
@@ -206,7 +167,7 @@ def run(args, training_callback: TrainingCallback = None):
     print_trainable_parameters(model)
 
     print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args)
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)
 
     # Resume training the given adapters.
     if args.resume_adapter_file is not None:
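load_dataset now also receives the tokenizer, since rendering chat-format records requires the tokenizer's (jinja2-backed) chat template. A rough, self-contained illustration of that step using the Hugging Face tokenizer API; the model name is only an example, not something the commit prescribes:

from transformers import AutoTokenizer

# Any model whose tokenizer ships a chat template works here.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

record = {
    "messages": [
        {"role": "user", "content": "What is MLX?"},
        {"role": "assistant", "content": "An array framework for Apple silicon."},
    ]
}

# Render the conversation into the single string that gets tokenized for training.
text = tokenizer.apply_chat_template(record["messages"], tokenize=False)
print(text)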