remerge with dpo

Goekdeniz-Guelmez 2025-01-19 01:14:08 +01:00
parent a9b7609118
commit 7d279b51ef


@@ -13,72 +13,11 @@ import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate
class TrainingCallback:

    def on_train_loss_report(self, train_info: dict):
        """Called to report training loss at specified intervals."""
        pass

    def on_val_loss_report(self, val_info: dict):
        """Called to report validation loss at specified intervals or the beginning."""
        pass
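For orientation (illustrative only, not part of this commit): the callback is consumed by subclassing and overriding its hooks. A minimal sketch; PrintCallback is a hypothetical name.

class PrintCallback(TrainingCallback):

    def on_train_loss_report(self, train_info: dict):
        print(f"train: {train_info}")

    def on_val_loss_report(self, val_info: dict):
        print(f"val: {val_info}")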
def grad_checkpoint(layer):
    """
    Update all instances of type(layer) to use gradient checkpointing.
    """
    fn = type(layer).__call__

    def checkpointed_fn(model, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            model.update(params)
            return fn(model, *args, **kwargs)

        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn
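Note (sketch, not from this commit): because grad_checkpoint patches type(layer).__call__, applying it to one layer affects every instance of that layer class. Assuming mlx is installed; the model here is hypothetical.

import mlx.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
grad_checkpoint(model.layers[0])  # patches nn.Linear.__call__, so both layers recompute activations in the backward pass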
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
@dataclass
class DPOTrainingArgs:
    # Original parameters
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    val_batches: int = field(
        default=25,
        metadata={
            "help": "Number of validation batches, -1 uses the entire validation set."
        },
    )
    steps_per_report: int = field(
        default=10,
        metadata={"help": "Number of training steps between loss reporting."},
    )
    steps_per_eval: int = field(
        default=200,
        metadata={"help": "Number of training steps between validations."},
    )
    steps_per_save: int = field(
        default=100,
        metadata={"help": "Save the model every N steps."},
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length."},
    )
    adapter_file: str = field(
        default="adapters.safetensors",
        metadata={"help": "Save/load path for the trained adapter weights."},
    )
    grad_checkpoint: bool = field(
        default=False,
        metadata={"help": "Use gradient checkpointing to reduce memory use."},
    )

# DPO-specific parameters
class DPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for DPO training."}
@@ -205,29 +144,6 @@ def dpo_loss(
return loss, reward, num_tokens
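For reference, the standard DPO objective (Rafailov et al., 2023) that a dpo_loss with this return signature typically implements, where $\beta$ is the beta field above and $y_w$, $y_l$ are the chosen and rejected completions:

$$\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]$$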
def compare(
    tokenizer,
    model: nn.Module,
    reference_teacher_model: nn.Module,
    prompt: str,
    temperature: float = 0.0,
    max_tokens: int = 1024
):
    """
    Generate comparison between policy and reference model completions.

    Args:
        tokenizer: Tokenizer shared by both models.
        model: Policy model.
        reference_teacher_model: Reference model.
        prompt: Prompt to start generation.
        temperature: Sampling temperature.
        max_tokens: Max number of tokens to generate.

    Returns:
        Completions.
    """
    reference_completion = ''.join([t[0] for t in generate(reference_teacher_model, tokenizer, prompt, temperature=temperature, max_tokens=max_tokens)])
    policy_completion = ''.join([t[0] for t in generate(model, tokenizer, prompt, temperature=temperature, max_tokens=max_tokens)])
    return reference_completion, policy_completion
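Hypothetical usage (paths and prompt are placeholders; load is the standard mlx_lm loader):

from mlx_lm import load

model, tokenizer = load("path/to/policy-model")
ref_model, _ = load("path/to/reference-model")
ref_text, policy_text = compare(
    tokenizer, model, ref_model,
    prompt="Explain DPO in one sentence.",
    max_tokens=128,
)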
def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Modified iterate_batches for DPO training that handles chosen and rejected samples.
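Not part of this commit: one way such an iterator could pair preference data, as a sketch assuming each dataset entry is a dict with tokenizable "chosen" and "rejected" texts.

import numpy as np

def iterate_dpo_batches_sketch(dataset, tokenizer, batch_size, max_seq_length, train=False):
    # Shuffle during training, iterate in order during evaluation.
    indices = np.random.permutation(len(dataset)) if train else np.arange(len(dataset))
    for start in range(0, len(dataset) - batch_size + 1, batch_size):
        batch = [dataset[int(i)] for i in indices[start : start + batch_size]]
        # Tokenize and truncate both branches of each preference pair.
        chosen = [tokenizer.encode(x["chosen"])[:max_seq_length] for x in batch]
        rejected = [tokenizer.encode(x["rejected"])[:max_seq_length] for x in batch]
        yield chosen, rejected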