diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index ab410373..d58cfb8d 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -13,72 +13,11 @@ import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 from ..generate import generate
-
-
-class TrainingCallback:
-
-    def on_train_loss_report(self, train_info: dict):
-        """Called to report training loss at specified intervals."""
-        pass
-
-    def on_val_loss_report(self, val_info: dict):
-        """Called to report validation loss at specified intervals or the beginning."""
-        pass
-
-
-def grad_checkpoint(layer):
-    """
-    Update all instances of type(layer) to use gradient checkpointing.
-    """
-    fn = type(layer).__call__
-
-    def checkpointed_fn(model, *args, **kwargs):
-        def inner_fn(params, *args, **kwargs):
-            model.update(params)
-            return fn(model, *args, **kwargs)
-
-        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
-
-    type(layer).__call__ = checkpointed_fn
+from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
 
 
 @dataclass
-class DPOTrainingArgs:
-    # Original parameters
-    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
-    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
-    val_batches: int = field(
-        default=25,
-        metadata={
-            "help": "Number of validation batches, -1 uses the entire validation set."
-        },
-    )
-    steps_per_report: int = field(
-        default=10,
-        metadata={"help": "Number of training steps between loss reporting."},
-    )
-    steps_per_eval: int = field(
-        default=200,
-        metadata={"help": "Number of training steps between validations."}
-    )
-    steps_per_save: int = field(
-        default=100,
-        metadata={"help": "Save the model every number steps"}
-    )
-    max_seq_length: int = field(
-        default=2048,
-        metadata={"help": "Maximum sequence length."}
-    )
-    adapter_file: str = field(
-        default="adapters.safetensors",
-        metadata={"help": "Save/load path for the trained adapter weights."},
-    )
-    grad_checkpoint: bool = field(
-        default=False,
-        metadata={"help": "Use gradient checkpointing to reduce memory use."},
-    )
-
-    # DPO-specific parameters
+class DPOTrainingArgs(TrainingArgs):
     beta: float = field(
         default=0.1, metadata={"help": "Temperature parameter for DPO training."}
     )
@@ -205,29 +144,6 @@ def dpo_loss(
     return loss, reward, num_tokens
 
 
-def compare(
-    tokenizer,
-    model: nn.Module,
-    reference_teacher_model: nn.Module,
-    prompt: str,
-    temperature: float = 0.0,
-    max_tokens: int = 1024
-):
-    """
-    Generate comparison between policy and reference model completions.
-    Args:
-        prompt: Prompt to start generation.
-        temperature: Sampling temperature.
-        max_tokens: Max number of tokens to generate.
-    Returns:
-        Completions.
-    """
-    reference_completion = ''.join([t[0] for t in generate(reference_teacher_model, tokenizer, prompt, temperature==temperature, max_tokens=max_tokens)])
-    policy_completion = ''.join([t[0] for t in generate(model, tokenizer, prompt, temperature=temperature, max_tokens=max_tokens)])
-
-    return reference_completion, policy_completion
-
-
 def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
     Modified iterate_batches for DPO training that handles chosen and rejected samples.
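
Usage sketch (not part of the diff above): because DPOTrainingArgs now subclasses the shared TrainingArgs from .trainer, the generic options removed in this change (batch_size, iters, max_seq_length, adapter_file, grad_checkpoint, ...) are inherited from the base dataclass, and only DPO-specific fields such as beta remain declared in dpo_trainer.py. The module path below is an assumption based on the file location; only fields visible in the diff are used.

# hypothetical import path, inferred from llms/mlx_lm/tuner/dpo_trainer.py
from mlx_lm.tuner.dpo_trainer import DPOTrainingArgs

args = DPOTrainingArgs(
    batch_size=4,         # inherited from TrainingArgs
    iters=1000,           # inherited from TrainingArgs
    max_seq_length=2048,  # inherited from TrainingArgs
    beta=0.1,             # DPO-specific field kept in DPOTrainingArgs
)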