diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index b31656c6..23e293c4 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -131,6 +131,7 @@ class CompletionsDataset: def create_dataset( + args, data, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -143,7 +144,14 @@ def create_dataset( if "messages" in sample: return ChatDataset(data, tokenizer) elif "prompt" in sample and "answer" in sample: - return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset + return GRPODataset( + data=data, + tokenizer=tokenizer, + prompt_key="prompt", + answer_key="answer", + use_chat_template=args.use_chat_template, + use_prompt=args.use_prompt + ) elif prompt_feature in sample and completion_feature in sample: return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: @@ -156,6 +164,7 @@ def create_dataset( def load_local_dataset( + args, data_path: Path, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -166,7 +175,7 @@ def load_local_dataset( return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(args, data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -174,6 +183,7 @@ def load_local_dataset( def load_hf_dataset( + args, data_id: str, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -189,7 +199,7 @@ def load_hf_dataset( train, valid, test = [ ( create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature + args, dataset[n], tokenizer, prompt_feature, completion_feature ) if n in dataset.keys() else [] @@ -254,12 +264,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature + args, data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature + args, args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: