# Copyright © 2024 Apple Inc. from functools import partial import glob import shutil import time from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Tuple from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer from mlx_lm.tokenizer_utils import TokenizerWrapper def grad_checkpoint(layer): """ Update all instances of type(layer) to use gradient checkpointing. """ fn = type(layer).__call__ def checkpointed_fn(model, *args, **kwargs): def inner_fn(params, *args, **kwargs): model.update(params) return fn(model, *args, **kwargs) return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) type(layer).__call__ = checkpointed_fn @dataclass class TrainingArgs: batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) iters: int = field(default=100, metadata={"help": "Iterations to train for."}) val_batches: int = field( default=25, metadata={ "help": "Number of validation batches, -1 uses the entire validation set." }, ) steps_per_report: int = field( default=10, metadata={"help": "Number of training steps between loss reporting."}, ) steps_per_eval: int = field( default=200, metadata={"help": "Number of training steps between validations."} ) steps_per_save: int = field( default=100, metadata={"help": "Save the model every number steps"} ) max_seq_length: int = field( default=2048, metadata={"help": "Maximum sequence length."} ) adapter_file: str = field( default="adapters.safetensors", metadata={"help": "Save/load path for the trained adapter weights."}, ) grad_checkpoint: bool = field( default=False, metadata={"help": "Use gradient checkpointing to reduce memory use."}, ) cot: bool = field( default=False, metadata={"help": "Use CoT loss masking with positioning penalty"}, ) reasoning_token: str = field( default="[REASONING]", metadata={"help": "Reasoning token"}, ) data_token: str = field( default="[DATA]", metadata={"help": "Final answer token"}, ) def default_loss(model, batch, lengths): inputs = batch[:, :-1] targets = batch[:, 1:] logits = model(inputs) logits = logits.astype(mx.float32) steps = mx.arange(1, targets.shape[1] + 1) mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) ce = nn.losses.cross_entropy(logits, targets) * mask ntoks = mask.sum() ce = ce.sum() / ntoks return ce, ntoks def cot_loss( model: nn.Module, inputs: mx.array, targets: mx.array, lengths: int, tokenizer: TokenizerWrapper, args: TrainingArgs, penalty: mx.float32 = 10.0, ) -> tuple[mx.array, mx.array]: logits = model(inputs).astype(mx.float32) reasoning_token_id = tokenizer.encode(args.reasoning_token)[0] data_token_id = tokenizer.encode(args.data_token)[0] reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) # find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA]) data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1) data_positions = targets.shape[1] - 1 - data_positions seq_indices = mx.arange(targets.shape[1])[None, :] # base CoT mask: starts at [DATA] cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32) # length mask: limits to non-padded regions length_mask = (seq_indices < lengths[:, None]).astype(mx.float32) # combine masks: only include tokens after [DATA] AND within sequence length loss_mask = cot_mask * length_mask # validate sequence structure valid_seq = ( (reasoning_positions < data_positions) & mx.any(targets == reasoning_token_id, axis=1) & mx.any(targets == data_token_id, axis=1) ) # compute base cross-entropy loss ce = nn.losses.cross_entropy(logits, targets) # masking loss before [DATA]; applying penalty for invalid seq valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8) final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty loss = mx.mean(final_loss) valid_tokens = mx.sum(loss_mask) + 1e-8 return loss, valid_tokens def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: raise ValueError( f"Dataset must have at least batch_size={batch_size}" f" examples but only has {len(dataset)}." ) # If running in distributed mode (N machines) then each one should skip N-1 # samples step = mx.distributed.init().size() if batch_size % step != 0: raise ValueError("The batch size must be divisible by the number of workers") # Make the batches: batch_idx = [ idx[i : i + batch_size : step] for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] if len(batch[0]) == 2: batch, offsets = zip(*batch) else: offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " "Consider pre-splitting your data to save memory." ) # Pad to the nearest multiple of 8 or the maximum length pad_to = 8 max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) yield batch, mx.array(list(zip(offsets, lengths))) if not train: break def evaluate( model, dataset, tokenizer, batch_size, num_batches, max_seq_length=2048, loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): all_losses = mx.array(0.0) ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) for _, batch in zip( index_iterator, iterate_batches( dataset=dataset, tokenizer=tokenizer, batch_size=batch_size, max_seq_length=max_seq_length, ), ): losses, toks = loss(model, *batch) all_losses += losses * toks ntokens += toks mx.eval(all_losses, ntokens) all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) return (all_losses / ntokens).item() class TrainingCallback: def on_train_loss_report(self, train_info: dict): """Called to report training loss at specified intervals.""" pass def on_val_loss_report(self, val_info: dict): """Called to report validation loss at specified intervals or the beginning.""" pass def train( model, tokenizer, optimizer, train_dataset, val_dataset, args: TrainingArgs = TrainingArgs(), loss: callable = default_loss, iterate_batches: callable = iterate_batches, training_callback: TrainingCallback = None, ): print(f"Starting training..., iters: {args.iters}") world = mx.distributed.init() world_size = world.size() rank = world.rank() if world_size > 1: print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) if args.cot: loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, args=args) else: loss = default_loss state = [model.state, optimizer.state] def step(batch): # Forward and backward pass (lvalue, toks), grad = loss_value_and_grad(model, *batch) # All reduce the gradients if running in distributed mode grad = average_gradients(grad) # Model update optimizer.update(model, grad) return lvalue, toks loss_value_and_grad = nn.value_and_grad(model, loss) losses = 0 n_tokens = 0 steps = 0 trained_tokens = 0 train_time = 0 # Main training loop start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_batches( dataset=train_dataset, tokenizer=tokenizer, batch_size=args.batch_size, max_seq_length=args.max_seq_length, train=True, ), ): tic = time.perf_counter() # Report validation loss if needed, the first validation loss # is always measured before any training. if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: tic = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, loss=loss, tokenizer=tokenizer, batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) val_time = time.perf_counter() - tic if rank == 0: print( f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s", flush=True, ) if training_callback is not None: val_info = { "iteration": it, "val_loss": val_loss, "val_time": val_time, } training_callback.on_val_loss_report(val_info) tic = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) train_time += time.perf_counter() - tic # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / train_time tokens_sec = float(n_tokens) / train_time trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: print( f"Iter {it}: Train loss {train_loss:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " f"Trained Tokens {trained_tokens}, " f"Peak mem {peak_mem:.3f} GB", flush=True, ) if training_callback is not None: train_info = { "iteration": it, "train_loss": train_loss, "learning_rate": learning_rate, "iterations_per_second": it_sec, "tokens_per_second": tokens_sec, "trained_tokens": trained_tokens, "peak_memory": peak_mem, } training_callback.on_train_loss_report(train_info) losses = 0 n_tokens = 0 steps = 0 train_time = 0 # Save adapter weights if it % args.steps_per_save == 0: adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) checkpoint = ( Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" ) mx.save_safetensors(str(checkpoint), adapter_weights) print( f"Iter {it}: Saved adapter weights to " f"{args.adapter_file} and {checkpoint}." ) # Save final weights adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) print(f"Saved final weights to {args.adapter_file}.")