diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 7d1821d4..392030cb 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -166,7 +166,7 @@ def iterate_delineated_batches(
         for j in batch_idx[i]:
             prompt, completion = dataset.get_prompt_and_completion(j)
             prompt_lengths.append(input_length(prompt, completion, tokenizer))
-            full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
+            full_sequence = tokenizer.encode(dataset[j])
             if full_sequence[-1] != tokenizer.eos_token_id:
                 full_sequence.append(tokenizer.eos_token_id)
             batch.append(full_sequence)
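
For context, here is a minimal, self-contained sketch of why the change matters. The old list comprehension encoded every index in the batch at once, so `full_sequence` was a list of token lists and `full_sequence[-1]` compared a whole list against `eos_token_id`; encoding `dataset[j]` inside the per-example loop yields a flat token list. The `ToyTokenizer`, toy dataset, and `EOS` value below are hypothetical stand-ins, not the mlx_lm tokenizer API.

```python
EOS = 2

class ToyTokenizer:
    """Hypothetical stand-in for a tokenizer with encode() and eos_token_id."""
    eos_token_id = EOS

    def encode(self, text):
        # Pretend each character maps to a token id.
        return [ord(c) % 100 for c in text]

tokenizer = ToyTokenizer()
dataset = ["hello", "world"]
batch_idx = [[0, 1]]
i = 0

# Old (buggy) form: one list of token lists for the whole batch.
buggy = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
print(type(buggy[-1]))  # <class 'list'> -- comparing this to eos_token_id is meaningless

# Fixed form: encode one example per loop iteration, then append EOS if missing.
batch = []
for j in batch_idx[i]:
    full_sequence = tokenizer.encode(dataset[j])
    if full_sequence[-1] != tokenizer.eos_token_id:
        full_sequence.append(tokenizer.eos_token_id)
    batch.append(full_sequence)

print(batch)  # each entry is a flat token list ending in EOS
```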