Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)
commit 7d279b51ef (parent a9b7609118): remerge with dpo
@@ -13,72 +13,11 @@ import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 from ..generate import generate
+from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
-
-
-class TrainingCallback:
-
-    def on_train_loss_report(self, train_info: dict):
-        """Called to report training loss at specified intervals."""
-        pass
-
-    def on_val_loss_report(self, val_info: dict):
-        """Called to report validation loss at specified intervals or the beginning."""
-        pass
-
-
-def grad_checkpoint(layer):
-    """
-    Update all instances of type(layer) to use gradient checkpointing.
-    """
-    fn = type(layer).__call__
-
-    def checkpointed_fn(model, *args, **kwargs):
-        def inner_fn(params, *args, **kwargs):
-            model.update(params)
-            return fn(model, *args, **kwargs)
-
-        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
-
-    type(layer).__call__ = checkpointed_fn
 
 
 @dataclass
-class DPOTrainingArgs:
+class DPOTrainingArgs(TrainingArgs):
-    # Original parameters
-    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
-    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
-    val_batches: int = field(
-        default=25,
-        metadata={
-            "help": "Number of validation batches, -1 uses the entire validation set."
-        },
-    )
-    steps_per_report: int = field(
-        default=10,
-        metadata={"help": "Number of training steps between loss reporting."},
-    )
-    steps_per_eval: int = field(
-        default=200,
-        metadata={"help": "Number of training steps between validations."}
-    )
-    steps_per_save: int = field(
-        default=100,
-        metadata={"help": "Save the model every number steps"}
-    )
-    max_seq_length: int = field(
-        default=2048,
-        metadata={"help": "Maximum sequence length."}
-    )
-    adapter_file: str = field(
-        default="adapters.safetensors",
-        metadata={"help": "Save/load path for the trained adapter weights."},
-    )
-    grad_checkpoint: bool = field(
-        default=False,
-        metadata={"help": "Use gradient checkpointing to reduce memory use."},
-    )
-
-    # DPO-specific parameters
     beta: float = field(
         default=0.1,
         metadata={"help": "Temperature parameter for DPO training."}
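The main change in this hunk is that DPOTrainingArgs now inherits from TrainingArgs, so the generic fields (batch size, iterations, reporting cadence, adapter path, gradient checkpointing) are defined once in trainer.py and only the DPO-specific beta stays here. A minimal sketch of the resulting dataclass layout; the TrainingArgs stand-in below is hypothetical and only mirrors a few of the fields deleted above:

from dataclasses import dataclass, field

# Hypothetical stand-in for the TrainingArgs imported from .trainer;
# the real class is assumed to define the fields removed above.
@dataclass
class TrainingArgs:
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    max_seq_length: int = field(default=2048, metadata={"help": "Maximum sequence length."})
    grad_checkpoint: bool = field(default=False, metadata={"help": "Use gradient checkpointing."})


@dataclass
class DPOTrainingArgs(TrainingArgs):
    # Only the DPO-specific knob lives in the subclass.
    beta: float = field(default=0.1, metadata={"help": "Temperature parameter for DPO training."})


# Inherited and DPO-specific fields share one constructor.
args = DPOTrainingArgs(batch_size=8, iters=500, beta=0.05)
print(args.max_seq_length, args.beta)  # -> 2048 0.05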
@@ -205,29 +144,6 @@ def dpo_loss(
     return loss, reward, num_tokens
-
-
-def compare(
-    tokenizer,
-    model: nn.Module,
-    reference_teacher_model: nn.Module,
-    prompt: str,
-    temperature: float = 0.0,
-    max_tokens: int = 1024
-):
-    """
-    Generate comparison between policy and reference model completions.
-    Args:
-        prompt: Prompt to start generation.
-        temperature: Sampling temperature.
-        max_tokens: Max number of tokens to generate.
-    Returns:
-        Completions.
-    """
-    reference_completion = ''.join([t[0] for t in generate(reference_teacher_model, tokenizer, prompt, temperature==temperature, max_tokens=max_tokens)])
-    policy_completion = ''.join([t[0] for t in generate(model, tokenizer, prompt, temperature=temperature, max_tokens=max_tokens)])
-
-    return reference_completion, policy_completion
 
 
 def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
     Modified iterate_batches for DPO training that handles chosen and rejected samples.
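This hunk drops the compare helper, which also carried a small bug: the reference-model call passed temperature==temperature (an equality test) instead of the keyword argument temperature=temperature. The same side-by-side check can still be written at the call site. A rough sketch, assuming generate (imported from ..generate at the top of the file) yields (text, token)-style tuples the way the deleted helper consumed them:

def compare_completions(tokenizer, policy_model, reference_model, prompt,
                        temperature=0.0, max_tokens=1024):
    """Generate the same prompt with the policy model and the frozen reference model."""

    def complete(model):
        # Assumes generate(...) is an iterator of (text, token) pairs, as in the
        # removed helper; adjust if the actual generate() signature differs.
        return ''.join(
            t[0]
            for t in generate(model, tokenizer, prompt,
                              temperature=temperature, max_tokens=max_tokens)
        )

    return complete(reference_model), complete(policy_model)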
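For context on the loss, reward, num_tokens returned by dpo_loss above: beta is the temperature of the standard DPO objective (Rafailov et al., 2023), which scores the policy/reference log-probability margin between chosen and rejected completions. This is not the repository's implementation, only a minimal MLX sketch of that objective on precomputed sequence log-probabilities:

import mlx.core as mx


def dpo_objective_sketch(policy_chosen_logps, policy_rejected_logps,
                         ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Illustrative DPO loss; a production version would use a stable log-sigmoid."""
    # Log-ratios of policy vs. reference for chosen and rejected completions.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # beta is the same temperature exposed by DPOTrainingArgs.beta.
    logits = beta * (chosen_ratio - rejected_ratio)
    loss = -mx.log(mx.sigmoid(logits)).mean()
    # Implicit rewards, typically logged for monitoring.
    chosen_reward = (beta * chosen_ratio).mean()
    rejected_reward = (beta * rejected_ratio).mean()
    return loss, (chosen_reward, rejected_reward)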