diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 153bc49d..bbb574b4 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -235,6 +235,10 @@ def train_model(
             build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
         )
     )
+
+    if args.mask_inputs:
+        print("Masking inputs..")
+
     # Train model
     train(
         model=model,
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index d74fddbb..9a251011 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -168,7 +168,7 @@ def iterate_delineated_batches(
         for j in batch_idx[i]:
            prompt, completion = dataset.get_prompt_and_completion(j)
            prompt_length, completion_length = input_and_output_lengths(
-                prompt, prompt, tokenizer
+                prompt, completion, tokenizer
            )
            prompt_lengths.append(prompt_length)
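
Reviewer note: the trainer.py hunk fixes a copy-paste bug where the prompt was passed twice, so `completion_length` was actually the prompt's token count, and the mask boundary landed at the wrong index whenever a prompt and its completion differed in length. For context, here is a minimal sketch of how these two lengths typically drive input masking; `input_and_output_lengths` and `completion_only_mask` below are hypothetical stand-ins for illustration, not the actual helpers in llms/mlx_lm/tuner/trainer.py:

import mlx.core as mx

def input_and_output_lengths(prompt, completion, tokenizer):
    # Hypothetical stand-in: count prompt and completion tokens separately
    # so the trainer knows where in the sequence the completion begins.
    prompt_length = len(tokenizer.encode(prompt))
    completion_length = len(tokenizer.encode(completion))
    return prompt_length, completion_length

def completion_only_mask(prompt_length, completion_length):
    # Positions before prompt_length are prompt tokens; masking them out
    # means only completion tokens contribute to the training loss.
    total_length = prompt_length + completion_length
    return mx.arange(total_length) >= prompt_length

With the pre-fix call `input_and_output_lengths(prompt, prompt, tokenizer)`, a mask built this way would flip at the wrong position for every example whose completion length differs from its prompt length, which is why the argument fix matters beyond style.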