From 35ecc17042dd856c2a7a951f4f7b1a86fc6d0e41 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 11 Feb 2025 17:07:08 +0100
Subject: [PATCH 1/2] fix

---
 llms/mlx_lm/tuner/datasets.py     | 2 +-
 llms/mlx_lm/tuner/grpo_trainer.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 5f00d3e3..32522c8d 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -310,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
     else:
         print(f"Loading Hugging Face dataset {args.data}.")
-        train, valid, test = load_hf_dataset(args.data, tokenizer, args)
+        train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 4a1e6bbf..da42da2b 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -117,7 +117,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
 
-    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
+    output = mx.zeros((prompt.shape[1] + max_tokens,))
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
 
@@ -126,7 +126,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         if temperature > 0:
             logits /= temperature
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
+        return mx.random.categorical(logprobs[None, :])[0]
 
     for _ in range(max_tokens):
         current_input = output[:current_length][None, :]

From 978deab589cda54ad750df8c28b833ef36354731 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 11 Feb 2025 17:48:42 +0100
Subject: [PATCH 2/2] small fix

---
 llms/mlx_lm/tuner/grpo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index da42da2b..4a1e6bbf 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -117,7 +117,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
 
-    output = mx.zeros((prompt.shape[1] + max_tokens,))
+    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
 
@@ -126,7 +126,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         if temperature > 0:
             logits /= temperature
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        return mx.random.categorical(logprobs[None, :])[0]
+        return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
 
     for _ in range(max_tokens):
         current_input = output[:current_length][None, :]
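
Reviewer note, not part of the patches above: PATCH 1/2 dropped the explicit integer dtype from the pre-allocated token buffer and from the sampled token, and PATCH 2/2 restores both. Below is a minimal sketch of why those dtypes matter, assuming only mlx.core; the buffer size, logits, and temperature are made-up values for illustration, not taken from the trainer.

    import mlx.core as mx

    # mx.zeros defaults to float32; this buffer holds token ids, so it
    # needs an explicit integer dtype to stay usable as model input.
    output = mx.zeros((8,), dtype=mx.int32)

    # Temperature sampling as in generate_grpo's sample helper: scale
    # the logits, then normalize them into log-probabilities.
    logits = mx.array([0.1, 2.0, 0.5])
    temperature = 0.8
    if temperature > 0:
        logits /= temperature
    logprobs = logits - mx.logsumexp(logits, keepdims=True)

    # mx.random.categorical draws one index per row and returns
    # unsigned integers; cast to int32 to match the output buffer.
    token = mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
    output[0] = token

Without the explicit dtypes, output is created as float32 and the sampled ids are stored back as floats; slicing that buffer into current_input then hands float "token ids" to the model, whose embedding lookup indexes a weight matrix and so expects integers. That would typically fail on the first generation step, which is presumably why PATCH 2/2 reverts PATCH 1/2's grpo_trainer.py changes.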