diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 01ffd81b..df5c4588 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -34,6 +34,7 @@ class GRPODataset:
The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""},
{'role': 'user', 'content': prompt_str}
],
+ add_generation_prompt=True
)
answer_tokens = tokenizer.encode(answer_str)
else:
@@ -307,10 +308,10 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
else:
data_path = Path(args.data)
if data_path.exists():
- train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
+ train, valid, test = load_local_dataset(args, data_path, tokenizer, args.config)
else:
print(f"Loading Hugging Face dataset {args.data}.")
- train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
+ train, valid, test = load_hf_dataset(args, args.data, tokenizer, args.config)
if args.train and len(train) == 0:
raise ValueError(