mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
added cot loss masking training
This commit is contained in:
parent
09b641aaa7
commit
68403f5577
@ -64,6 +64,12 @@ lora_parameters:
|
||||
scale: 20.0
|
||||
dropout: 0.0
|
||||
|
||||
# cot loss masking training
|
||||
# cot:
|
||||
# use_cot: true
|
||||
# special: true
|
||||
# additional_tokens: ["[REASONING]", "[DATA]"]
|
||||
|
||||
# Schedule can only be specified in a config file, uncomment to use.
|
||||
#lr_schedule:
|
||||
# name: cosine_decay
|
||||
|
@ -62,6 +62,7 @@ CONFIG_DEFAULTS = {
|
||||
"grad_checkpoint": False,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
"cot": False,
|
||||
}
|
||||
|
||||
|
||||
@ -78,7 +79,6 @@ def build_parser():
|
||||
"--train",
|
||||
action="store_true",
|
||||
help="Do training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
@ -94,14 +94,6 @@ def build_parser():
|
||||
choices=["lora", "dora", "full"],
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prompt",
|
||||
action="store_true",
|
||||
help="Mask the prompt in the loss when training",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
@ -144,7 +136,6 @@ def build_parser():
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batches",
|
||||
@ -166,9 +157,13 @@ def build_parser():
|
||||
"--grad-checkpoint",
|
||||
action="store_true",
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||
parser.add_argument(
|
||||
"--cot",
|
||||
type=bool,
|
||||
help="Use CoT loss masking",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -181,14 +176,8 @@ def train_model(
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
model.freeze()
|
||||
if args.num_layers > len(model.layers):
|
||||
raise ValueError(
|
||||
f"Requested to train {args.num_layers} layers "
|
||||
f"but the model only has {len(model.layers)} layers."
|
||||
)
|
||||
|
||||
if args.fine_tune_type == "full":
|
||||
for l in model.layers[-max(args.num_layers, 0) :]:
|
||||
for l in model.layers[-min(args.num_layers, 0) :]:
|
||||
l.unfreeze()
|
||||
elif args.fine_tune_type in ["lora", "dora"]:
|
||||
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||
@ -225,10 +214,13 @@ def train_model(
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
cot=(cot := args.cot),
|
||||
)
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam(
|
||||
# todo optimizer from args
|
||||
|
||||
opt = optim.AdamW(
|
||||
learning_rate=(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
@ -269,6 +261,21 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
if cot := args.cot:
|
||||
print("Using CoT loss masking")
|
||||
if tokens := cot.get("additional_tokens"):
|
||||
from .tuner.new_tokens import implement_new_tokens
|
||||
|
||||
special = False
|
||||
if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
|
||||
print("Updating model and tokenizer with new special tokens")
|
||||
special = special_arg
|
||||
else:
|
||||
print("Updating model and tokenizer with new tokens")
|
||||
model, tokenizer = implement_new_tokens(
|
||||
model=model, tokenizer=tokenizer, tokens=tokens, special=special
|
||||
)
|
||||
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||
|
||||
@ -293,6 +300,7 @@ def main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
config = args.config
|
||||
|
||||
args = vars(args)
|
||||
if config:
|
||||
print("Loading configuration file", config)
|
||||
|
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
@ -0,0 +1,162 @@
|
||||
import mlx.nn as nn
|
||||
import mlx.core as mx
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
|
||||
def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
|
||||
"""
|
||||
Resizes model embeddings to accommodate new tokens
|
||||
"""
|
||||
old_embedding = model.model.embed_tokens
|
||||
|
||||
old_vocab_size = old_embedding.num_embeddings
|
||||
new_vocab_size = len(tokenizer._tokenizer)
|
||||
|
||||
if old_vocab_size != new_vocab_size:
|
||||
if new_vocab_size < old_vocab_size:
|
||||
print(
|
||||
"Warning: New vocab size is smaller than original. Proceeding with trim."
|
||||
)
|
||||
|
||||
# check if QuantizedEmbedding has required attributes for dequantization
|
||||
try:
|
||||
dequantized_weights = mx.dequantize(
|
||||
old_embedding.weight,
|
||||
scales=old_embedding.scales,
|
||||
biases=old_embedding.biases,
|
||||
group_size=old_embedding.group_size,
|
||||
bits=old_embedding.bits,
|
||||
)
|
||||
except AttributeError as e:
|
||||
print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
|
||||
print("Falling back to random weights for embed_tokens.")
|
||||
dequantized_weights = mx.random.normal(
|
||||
(old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
|
||||
)
|
||||
|
||||
# resize embed_tokens
|
||||
new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
|
||||
new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
|
||||
min_vocab_size = min(old_vocab_size, new_vocab_size)
|
||||
new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size]
|
||||
if new_vocab_size > old_vocab_size:
|
||||
new_weights[old_vocab_size:] = mx.random.normal(
|
||||
(new_vocab_size - old_vocab_size, old_embedding.dims),
|
||||
loc=0.0,
|
||||
scale=0.02,
|
||||
)
|
||||
new_embedding.weight = new_weights
|
||||
model.model.embed_tokens = new_embedding
|
||||
|
||||
# attention layers handling
|
||||
if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
|
||||
model.model.embed_tokens.weight = new_weights
|
||||
elif hasattr(model, "lm_head"):
|
||||
old_lm_head = model.lm_head
|
||||
if isinstance(old_lm_head, nn.QuantizedLinear):
|
||||
# resize nn.QuantizedLinear
|
||||
output_dims, compressed_input_dims = old_lm_head.weight.shape
|
||||
bits = old_lm_head.bits
|
||||
input_dims = compressed_input_dims * (32 // bits)
|
||||
|
||||
# dequantize lm_head weights
|
||||
try:
|
||||
dequantized_lm_weights = mx.dequantize(
|
||||
old_lm_head.weight,
|
||||
scales=old_lm_head.scales,
|
||||
biases=old_lm_head.biases,
|
||||
group_size=old_lm_head.group_size,
|
||||
bits=old_lm_head.bits,
|
||||
)
|
||||
except AttributeError as e:
|
||||
print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
|
||||
print("Falling back to random weights for lm_head.")
|
||||
dequantized_lm_weights = mx.random.normal(
|
||||
(output_dims, input_dims), loc=0.0, scale=0.02
|
||||
)
|
||||
|
||||
new_lm_head = nn.QuantizedLinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=new_vocab_size,
|
||||
bias="bias" in old_lm_head,
|
||||
group_size=old_lm_head.group_size,
|
||||
bits=old_lm_head.bits,
|
||||
)
|
||||
new_weights_lm = mx.zeros((new_vocab_size, input_dims))
|
||||
new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
|
||||
:min_vocab_size
|
||||
]
|
||||
if new_vocab_size > output_dims:
|
||||
new_weights_lm[output_dims:] = mx.random.normal(
|
||||
(new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
|
||||
)
|
||||
new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
|
||||
mx.quantize(
|
||||
new_weights_lm, new_lm_head.group_size, new_lm_head.bits
|
||||
)
|
||||
)
|
||||
if "bias" in old_lm_head:
|
||||
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||
:min_vocab_size
|
||||
]
|
||||
else:
|
||||
# resize nn.Linear
|
||||
new_lm_head = nn.Linear(
|
||||
old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
|
||||
)
|
||||
new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
|
||||
min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
|
||||
new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
|
||||
if new_vocab_size > old_lm_head.weight.shape[0]:
|
||||
new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
|
||||
(
|
||||
new_vocab_size - old_lm_head.weight.shape[0],
|
||||
old_lm_head.input_dims,
|
||||
),
|
||||
loc=0.0,
|
||||
scale=0.02,
|
||||
)
|
||||
new_lm_head.weight = new_weights_lm
|
||||
# todo typechecking
|
||||
if "bias" in old_lm_head:
|
||||
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||
:min_vocab_size
|
||||
]
|
||||
|
||||
model.lm_head = new_lm_head
|
||||
else:
|
||||
print("Vocab already sized right.")
|
||||
return model
|
||||
|
||||
|
||||
def update_tokenizer(
|
||||
tokenizer: TokenizerWrapper, tokens: list[str], special: bool
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Appends new tokens to the end of the tokenizer vocab
|
||||
"""
|
||||
if special:
|
||||
# todo TokenizerWrapper access method
|
||||
tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
|
||||
print(f"Tokenizer updated with special tokens: {tokens}")
|
||||
print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}")
|
||||
else:
|
||||
# todo add regular tokens
|
||||
pass
|
||||
return tokenizer
|
||||
|
||||
|
||||
def implement_new_tokens(
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokens: list[str],
|
||||
special: bool = False,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
"""
|
||||
Update model`s tokenizer and embeddings with new tokens accordingly
|
||||
"""
|
||||
tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special)
|
||||
model = resize_embeddings(model=model, tokenizer=tokenizer)
|
||||
return model, tokenizer
|
@ -1,20 +1,20 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
import glob
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
@ -64,32 +64,80 @@ class TrainingArgs:
|
||||
default=False,
|
||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
||||
)
|
||||
cot: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use CoT loss masking with positioning penalty"},
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, batch, lengths):
|
||||
inputs = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
steps = mx.arange(1, targets.shape[1] + 1)
|
||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||
ntoks = mask.sum()
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
max_seq_length,
|
||||
train=False,
|
||||
):
|
||||
@dataclass
|
||||
class CotTrainingArgs:
|
||||
cot: bool = False
|
||||
reasoning_token: str = "[REASONING]"
|
||||
data_token: str = "[DATA]"
|
||||
|
||||
|
||||
def cot_loss(
|
||||
model: nn.Module,
|
||||
inputs: mx.array,
|
||||
targets: mx.array,
|
||||
lengths: int,
|
||||
tokenizer: TokenizerWrapper,
|
||||
penalty: mx.float32 = 10.0,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
logits = model(inputs).astype(mx.float32)
|
||||
|
||||
reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
|
||||
data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
|
||||
|
||||
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
|
||||
data_positions = mx.argmax(targets == data_token_id, axis=1)
|
||||
|
||||
seq_indices = mx.arange(targets.shape[1])[None, :]
|
||||
|
||||
# base CoT mask: starts at [DATA]
|
||||
cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32)
|
||||
|
||||
# length mask: limits to non-padded regions
|
||||
length_mask = (seq_indices < lengths[:, None]).astype(mx.float32)
|
||||
|
||||
# combine masks: only include tokens after [DATA] AND within sequence length
|
||||
loss_mask = cot_mask * length_mask
|
||||
|
||||
# validate sequence structure
|
||||
valid_seq = (
|
||||
(reasoning_positions < data_positions)
|
||||
& mx.any(targets == reasoning_token_id, axis=1)
|
||||
& mx.any(targets == data_token_id, axis=1)
|
||||
)
|
||||
|
||||
# compute base cross-entropy loss
|
||||
ce = nn.losses.cross_entropy(logits, targets)
|
||||
|
||||
# masking loss before [DATA]; applying penalty for invalid seq
|
||||
valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8)
|
||||
final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty
|
||||
loss = mx.mean(final_loss)
|
||||
|
||||
valid_tokens = mx.sum(loss_mask) + 1e-8
|
||||
|
||||
return loss, valid_tokens
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
@ -114,10 +162,6 @@ def iterate_batches(
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
if len(batch[0]) == 2:
|
||||
batch, offsets = zip(*batch)
|
||||
else:
|
||||
offsets = [0] * len(batch)
|
||||
lengths = [len(x) for x in batch]
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
@ -140,7 +184,8 @@ def iterate_batches(
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
batch = mx.array(batch_arr)
|
||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
|
||||
if not train:
|
||||
break
|
||||
@ -156,8 +201,8 @@ def evaluate(
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = mx.array(0.0)
|
||||
ntokens = mx.array(0)
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
@ -213,6 +258,11 @@ def train(
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
|
||||
if args.cot:
|
||||
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
|
||||
else:
|
||||
loss = default_loss
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
@ -233,8 +283,8 @@ def train(
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
train_time = 0
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
range(1, args.iters + 1),
|
||||
iterate_batches(
|
||||
@ -245,11 +295,10 @@ def train(
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
tic = time.perf_counter()
|
||||
# Report validation loss if needed, the first validation loss
|
||||
# is always measured before any training.
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
tic = time.perf_counter()
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
@ -260,7 +309,7 @@ def train(
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - tic
|
||||
val_time = time.perf_counter() - stop
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: "
|
||||
@ -277,23 +326,24 @@ def train(
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
|
||||
tic = time.perf_counter()
|
||||
start = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
losses += lvalue
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, n_tokens)
|
||||
train_time += time.perf_counter() - tic
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / train_time
|
||||
tokens_sec = float(n_tokens) / train_time
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
if rank == 0:
|
||||
@ -322,7 +372,7 @@ def train(
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
train_time = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
if it % args.steps_per_save == 0:
|
||||
|
Loading…
Reference in New Issue
Block a user