# Copyright © 2024 Apple Inc. import glob import shutil import time from dataclasses import dataclass, field from pathlib import Path from typing import List, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer from .datasets import CompletionsDataset def grad_checkpoint(layer): """ Update all instances of type(layer) to use gradient checkpointing. """ fn = type(layer).__call__ def checkpointed_fn(model, *args, **kwargs): def inner_fn(params, *args, **kwargs): model.update(params) return fn(model, *args, **kwargs) return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) type(layer).__call__ = checkpointed_fn @dataclass class TrainingArgs: batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) iters: int = field(default=100, metadata={"help": "Iterations to train for."}) val_batches: int = field( default=25, metadata={ "help": "Number of validation batches, -1 uses the entire validation set." }, ) steps_per_report: int = field( default=10, metadata={"help": "Number of training steps between loss reporting."}, ) steps_per_eval: int = field( default=200, metadata={"help": "Number of training steps between validations."} ) steps_per_save: int = field( default=100, metadata={"help": "Save the model every number steps"} ) max_seq_length: int = field( default=2048, metadata={"help": "Maximum sequence length."} ) adapter_file: str = field( default="adapters.safetensors", metadata={"help": "Save/load path for the trained adapter weights."}, ) grad_checkpoint: bool = field( default=False, metadata={"help": "Use gradient checkpointing to reduce memory use."}, ) def input_masked_loss(model, inputs, input_lengths, lengths): shifted_inputs = inputs[:, :-1] shifted_labels = inputs[:, 1:] logits = model(shifted_inputs) logits = logits.astype(mx.float32) mask_width = shifted_inputs.shape[1] token_indices = mx.arange(mask_width)[None, :] mask = mx.logical_and( token_indices >= input_lengths[:, None], token_indices < lengths[:, None] ) ce = nn.losses.cross_entropy(logits, shifted_labels) * mask ntoks = mask.sum() ce = ce.sum() / ntoks return ce, ntoks def default_loss(model, inputs, targets, lengths): logits = model(inputs) logits = logits.astype(mx.float32) length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] ce = nn.losses.cross_entropy(logits, targets) * length_mask ntoks = length_mask.sum() ce = ce.sum() / ntoks return ce, ntoks def contains(small_list: List, big_list: List) -> Tuple[int, int]: """ Returns the beginning and end index of the first occurrence of small_list in big_list. """ small_list_length = len(small_list) for ind in (i for i, e in enumerate(big_list) if e == small_list[0]): if big_list[ind : ind + small_list_length] == small_list: return ind, ind + small_list_length - 1 def no_bos(sequence: List, bos: int) -> List: return sequence if sequence[0] != bos else sequence[1:] def input_length( input_text: str, output_text: str, tokenizer: PreTrainedTokenizer ) -> int: """ Returns the length of the portion of the encoding of the concatenation of input_text and output_text that corresponds to the input tokens. """ message = [ {"role": "user", "content": input_text}, {"role": "assistant", "content": output_text}, ] output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id) full_sequence = tokenizer.apply_chat_template( message, add_generation_prompt=True, tokenize=True ) output_begin, output_end = contains(output_tokens, full_sequence) return output_begin def iterate_delineated_batches( dataset: CompletionsDataset, tokenizer: PreTrainedTokenizer, batch_size: int, max_seq_length: int, train: bool = False, ): """ A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens and returns the lengths of input tokens as well as that of the full sequences. """ idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i])) if len(dataset) < batch_size: raise ValueError( f"Dataset must have at least batch_size={batch_size}" f" examples but only has {len(dataset)}." ) # If running in distributed mode (N machines) then each one should skip N-1 # samples step = mx.distributed.init().size() if batch_size % step != 0: raise ValueError("The batch size must be divisible by the number of workers") # Make the batches: batch_idx = [ idx[i : i + batch_size : step] for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: indices = np.random.permutation(len(batch_idx)) for i in indices: prompt_lengths = [] batch = [] for j in batch_idx[i]: prompt, completion = dataset.get_prompt_and_completion(j) prompt_lengths.append(input_length(prompt, completion, tokenizer)) full_sequence = tokenizer.encode(dataset[j]) if full_sequence[-1] != tokenizer.eos_token_id: full_sequence.append(tokenizer.eos_token_id) batch.append(full_sequence) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " "Consider pre-splitting your data to save memory." ) # Pad to the nearest multiple of 8 or the maximum length pad_to = 8 max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( truncated_length # Update lengths to match truncated lengths ) yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths) if not train: break def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: raise ValueError( f"Dataset must have at least batch_size={batch_size}" f" examples but only has {len(dataset)}." ) # If running in distributed mode (N machines) then each one should skip N-1 # samples step = mx.distributed.init().size() if batch_size % step != 0: raise ValueError("The batch size must be divisible by the number of workers") # Make the batches: batch_idx = [ idx[i : i + batch_size : step] for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " "Consider pre-splitting your data to save memory." ) # Pad to the nearest multiple of 8 or the maximum length pad_to = 8 max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) yield batch[:, :-1], batch[:, 1:], mx.array(lengths) if not train: break def evaluate( model, dataset, tokenizer, batch_size, num_batches, max_seq_length=2048, loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): all_losses = mx.array(0.0) ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) for _, batch in zip( index_iterator, iterate_batches( dataset=dataset, tokenizer=tokenizer, batch_size=batch_size, max_seq_length=max_seq_length, ), ): losses, toks = loss(model, *batch) all_losses += losses * toks ntokens += toks mx.eval(all_losses, ntokens) all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) return (all_losses / ntokens).item() class TrainingCallback: def on_train_loss_report(self, train_info: dict): """Called to report training loss at specified intervals.""" pass def on_val_loss_report(self, val_info: dict): """Called to report validation loss at specified intervals or the beginning.""" pass def train( model, tokenizer, optimizer, train_dataset, val_dataset, args: TrainingArgs = TrainingArgs(), loss: callable = default_loss, iterate_batches: callable = iterate_batches, training_callback: TrainingCallback = None, ): print(f"Starting training..., iters: {args.iters}") world = mx.distributed.init() world_size = world.size() rank = world.rank() if world_size > 1: print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) state = [model.state, optimizer.state] def step(batch): # Forward and backward pass (lvalue, toks), grad = loss_value_and_grad(model, *batch) # All reduce the gradients if running in distributed mode grad = average_gradients(grad) # Model update optimizer.update(model, grad) return lvalue, toks loss_value_and_grad = nn.value_and_grad(model, loss) losses = 0 n_tokens = 0 steps = 0 trained_tokens = 0 # Main training loop start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_batches( dataset=train_dataset, tokenizer=tokenizer, batch_size=args.batch_size, max_seq_length=args.max_seq_length, train=True, ), ): # Report validation loss if needed, the first validation loss # is always measured before any training. if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: stop = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, loss=loss, tokenizer=tokenizer, batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) val_time = time.perf_counter() - stop if rank == 0: print( f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s", flush=True, ) if training_callback is not None: val_info = { "iteration": it, "val_loss": val_loss, "val_time": val_time, } training_callback.on_val_loss_report(val_info) start = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: print( f"Iter {it}: Train loss {train_loss:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " f"Trained Tokens {trained_tokens}, " f"Peak mem {peak_mem:.3f} GB", flush=True, ) if training_callback is not None: train_info = { "iteration": it, "train_loss": train_loss, "learning_rate": learning_rate, "iterations_per_second": it_sec, "tokens_per_second": tokens_sec, "trained_tokens": trained_tokens, "peak_memory": peak_mem, } training_callback.on_train_loss_report(train_info) losses = 0 n_tokens = 0 steps = 0 start = time.perf_counter() # Save adapter weights if it % args.steps_per_save == 0: adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) checkpoint = ( Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" ) mx.save_safetensors(str(checkpoint), adapter_weights) print( f"Iter {it}: Saved adapter weights to " f"{args.adapter_file} and {checkpoint}." ) # Save final weights adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) print(f"Saved final weights to {args.adapter_file}.")