From 30fd5af843e2600165b5dd2318e33f53d8f0f9f6 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Wed, 6 Nov 2024 12:33:49 -0500
Subject: [PATCH] Fix variable reference

---
 llms/mlx_lm/lora.py          | 4 ++++
 llms/mlx_lm/tuner/trainer.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 153bc49d..bbb574b4 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -235,6 +235,10 @@ def train_model(
             build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
         )
     )
+
+    if args.mask_inputs:
+        print("Masking inputs..")
+
     # Train model
     train(
         model=model,
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index d74fddbb..9a251011 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -168,7 +168,7 @@ def iterate_delineated_batches(
         for j in batch_idx[i]:
             prompt, completion = dataset.get_prompt_and_completion(j)
             prompt_length, completion_length = input_and_output_lengths(
-                prompt, prompt, tokenizer
+                prompt, completion, tokenizer
             )
             prompt_lengths.append(prompt_length)