From e477060a00e7d2fdc5f6d44f629a6fd17eef684b Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 3 Nov 2024 20:26:15 -0500 Subject: [PATCH] Fix keyword argument invocation --- llms/mlx_lm/tuner/datasets.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 3b442c6a..c75171e5 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -194,10 +194,20 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): valid_split = hf_args.get("valid_split", "train[-10%:]") text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args) train = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split + dataset_name=ds_name, + text_feature=text_f, + prompt_feature=prompt_f, + completion_feature=completion_f, + chat_feature=chat_f, + split=train_split, ) valid = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split + dataset_name=ds_name, + text_feature=text_f, + prompt_feature=prompt_f, + completion_feature=completion_f, + chat_feature=chat_f, + split=valid_split, ) return train, valid @@ -219,11 +229,11 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): train, valid = [], [] if args.test: test = create_hf_dataset( - dataset_name, - text_feature, - prompt_feature, - completion_feature, - chat_f, + dataset_name=dataset_name, + text_feature=text_feature, + prompt_feature=prompt_feature, + completion_feature=completion_feature, + chat_feature=chat_f, split=hf_args.get("test_split"), ) else: