diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 6edea28d..eb0a279e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -79,6 +79,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -94,6 +95,12 @@
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -136,6 +143,7 @@
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -157,6 +165,7 @@
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     parser.add_argument(
@@ -176,6 +185,11 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
     if args.fine_tune_type == "full":
         for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 4f5db2ea..24e93f92 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -6,6 +6,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
+from typing import List, Optional, Tuple
 from typing import Union
 
 import mlx.core as mx
@@ -13,7 +14,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-
+from transformers import PreTrainedTokenizer
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 
 
@@ -70,14 +71,18 @@ class TrainingArgs:
     )
 
 
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
 
     return ce, ntoks
@@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                 )
             batch = mx.array(batch_arr)
 
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break
@@ -201,8 +210,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
@@ -283,6 +292,7 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -295,10 +305,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -309,7 +320,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -326,24 +337,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)
 
-            start = time.perf_counter()
+            tic = time.perf_counter()
 
         lvalue, toks = step(batch)
        losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic
 
         # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -372,7 +382,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
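
Note for reviewers: below is a minimal, self-contained sketch (illustration only, not part of the diff) of how the `(offset, length)` pairs yielded by the updated `iterate_batches` become the loss mask built in the new `default_loss`. Target positions before `offset` (the prompt) and past `length` (padding) are dropped from the cross-entropy. The batch size, number of target positions, and the `offsets`/`lengths` values here are made-up toy values.

```python
import mlx.core as mx

# Sequence 1: a 3-token prompt, 5 real tokens in total (rest is padding).
# Sequence 2: no prompt (offset 0), 4 real tokens in total.
offsets = [3, 0]
lengths = [5, 4]
length_info = mx.array(list(zip(offsets, lengths)))  # shape (B, 2), as yielded

# targets = batch[:, 1:]; step i indexes the i-th target position.
num_target_positions = 5
steps = mx.arange(1, num_target_positions + 1)

# Keep positions at or past the prompt boundary and inside the real sequence.
mask = mx.logical_and(steps >= length_info[:, 0:1], steps <= length_info[:, 1:])
print(mask)
# [[False, False, True, True, True],
#  [True,  True,  True, True, False]]
```

Without `--mask-prompt`, offsets stay 0 and the mask reduces to the old behavior of masking only the padding.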