mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 09:51:19 +08:00)
remerge with dpo
This commit is contained in:
parent a9b7609118
commit 7d279b51ef
@@ -13,72 +13,11 @@ import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate


class TrainingCallback:

    def on_train_loss_report(self, train_info: dict):
        """Called to report training loss at specified intervals."""
        pass

    def on_val_loss_report(self, val_info: dict):
        """Called to report validation loss at specified intervals or the beginning."""
        pass


def grad_checkpoint(layer):
    """
    Update all instances of type(layer) to use gradient checkpointing.
    """
    fn = type(layer).__call__

    def checkpointed_fn(model, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            model.update(params)
            return fn(model, *args, **kwargs)

        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn
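As an aside, grad_checkpoint is typically applied once to a representative transformer block before training; a minimal usage sketch, assuming the model exposes its blocks as model.layers (an assumption, not shown in this commit):

# Usage sketch (assumed model layout, not part of this diff): patching the
# block class once makes every block of that type recompute activations in
# the backward pass instead of storing them.
if training_args.grad_checkpoint:
    grad_checkpoint(model.layers[0])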
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs


@dataclass
class DPOTrainingArgs:
    # Original parameters
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    val_batches: int = field(
        default=25,
        metadata={
            "help": "Number of validation batches, -1 uses the entire validation set."
        },
    )
    steps_per_report: int = field(
        default=10,
        metadata={"help": "Number of training steps between loss reporting."},
    )
    steps_per_eval: int = field(
        default=200,
        metadata={"help": "Number of training steps between validations."}
    )
    steps_per_save: int = field(
        default=100,
        metadata={"help": "Save the model every N steps."}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length."}
    )
    adapter_file: str = field(
        default="adapters.safetensors",
        metadata={"help": "Save/load path for the trained adapter weights."},
    )
    grad_checkpoint: bool = field(
        default=False,
        metadata={"help": "Use gradient checkpointing to reduce memory use."},
    )

    # DPO-specific parameters
class DPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for DPO training."}
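For orientation, a hedged sketch of how the merged DPOTrainingArgs might be constructed; every value is illustrative, and TrainingArgs is assumed to carry the shared fields listed above:

# Illustrative configuration only; field names follow the definitions above.
args = DPOTrainingArgs(
    batch_size=4,
    iters=1000,
    max_seq_length=2048,
    grad_checkpoint=True,  # inherited from TrainingArgs after the merge
    beta=0.1,              # DPO temperature
)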
@@ -205,29 +144,6 @@ def dpo_loss(
    return loss, reward, num_tokens
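For reference, the standard DPO objective (Rafailov et al., 2023) that dpo_loss presumably follows, written as a minimal sketch over per-sequence log-probabilities; every name below is illustrative rather than taken from this commit:

import mlx.core as mx

def dpo_loss_sketch(policy_chosen_lp, policy_rejected_lp,
                    ref_chosen_lp, ref_rejected_lp, beta=0.1):
    # Log-ratios of the policy against the frozen reference model.
    chosen_logratio = policy_chosen_lp - ref_chosen_lp
    rejected_logratio = policy_rejected_lp - ref_rejected_lp
    # -log(sigmoid(x)) == logaddexp(0, -x), a numerically stable softplus.
    logits = beta * (chosen_logratio - rejected_logratio)
    loss = mx.mean(mx.logaddexp(0.0, -logits))
    # Implicit rewards, commonly logged to monitor training progress.
    rewards = (beta * mx.mean(chosen_logratio), beta * mx.mean(rejected_logratio))
    return loss, rewards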
def compare(
    tokenizer,
    model: nn.Module,
    reference_teacher_model: nn.Module,
    prompt: str,
    temperature: float = 0.0,
    max_tokens: int = 1024
):
    """
    Generate comparison between policy and reference model completions.
    Args:
        prompt: Prompt to start generation.
        temperature: Sampling temperature.
        max_tokens: Max number of tokens to generate.
    Returns:
        Completions.
    """
    reference_completion = ''.join(
        t[0] for t in generate(
            reference_teacher_model, tokenizer, prompt,
            temperature=temperature, max_tokens=max_tokens
        )
    )
    policy_completion = ''.join(
        t[0] for t in generate(
            model, tokenizer, prompt,
            temperature=temperature, max_tokens=max_tokens
        )
    )

    return reference_completion, policy_completion
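A hedged usage sketch for compare; the model, tokenizer, and prompt below are placeholders rather than values from this commit:

# Illustrative only: quick qualitative check of policy vs. reference output.
ref_text, policy_text = compare(
    tokenizer,
    model,
    reference_teacher_model=ref_model,
    prompt="Summarize direct preference optimization in one sentence.",
    temperature=0.0,
    max_tokens=256,
)
print("reference:", ref_text)
print("policy:   ", policy_text)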
def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Modified iterate_batches for DPO training that handles chosen and rejected samples.