Add CoT loss masking training

This commit is contained in:
paNikitin 2025-02-23 12:31:44 +03:00
parent 09b641aaa7
commit 68403f5577
4 changed files with 279 additions and 53 deletions

View File

@@ -64,6 +64,12 @@ lora_parameters:
scale: 20.0
dropout: 0.0
# CoT loss masking training, uncomment to use.
# cot:
# use_cot: true
# special: true
# additional_tokens: ["[REASONING]", "[DATA]"]
# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
# name: cosine_decay
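For reference, an enabled `cot` section (values taken from the commented defaults above; illustrative only, since this diff does not show whether `use_cot` is read separately from the presence of the `cot` mapping itself) would look like:

```yaml
cot:
  use_cot: true
  special: true
  additional_tokens: ["[REASONING]", "[DATA]"]
```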

View File

@@ -62,6 +62,7 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"cot": False,
}
@@ -78,7 +79,6 @@ def build_parser():
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
@@ -94,14 +94,6 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)
parser.add_argument(
"--num-layers",
type=int,
@@ -144,7 +136,6 @@ def build_parser():
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
@@ -166,9 +157,13 @@ def build_parser():
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
"--cot",
action="store_true", default=None,
help="Use CoT loss masking; token settings are read from the 'cot' section of the config file",
)
return parser
@@ -181,14 +176,8 @@ def train_model(
training_callback: TrainingCallback = None,
):
model.freeze()
if args.num_layers > len(model.layers):
raise ValueError(
f"Requested to train {args.num_layers} layers "
f"but the model only has {len(model.layers)} layers."
)
if args.fine_tune_type == "full":
for l in model.layers[-max(args.num_layers, 0) :]:
for l in model.layers[-min(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
@@ -225,10 +214,13 @@ def train_model(
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
cot=args.cot,
)
model.train()
opt = optim.Adam(
# TODO: make the optimizer configurable via args
opt = optim.AdamW(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
@@ -269,6 +261,21 @@ def run(args, training_callback: TrainingCallback = None):
print("Loading pretrained model")
model, tokenizer = load(args.model)
if cot := args.cot:
print("Using CoT loss masking")
if tokens := cot.get("additional_tokens"):
from .tuner.new_tokens import implement_new_tokens
special = False
if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
print("Updating model and tokenizer with new special tokens")
special = special_arg
else:
print("Updating model and tokenizer with new tokens")
model, tokenizer = implement_new_tokens(
model=model, tokenizer=tokenizer, tokens=tokens, special=special
)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
@@ -293,6 +300,7 @@ def main():
parser = build_parser()
args = parser.parse_args()
config = args.config
args = vars(args)
if config:
print("Loading configuration file", config)

View File

@@ -0,0 +1,162 @@
import mlx.nn as nn
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper
def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
"""
Resizes model embeddings to accommodate new tokens
"""
old_embedding = model.model.embed_tokens
old_vocab_size = old_embedding.num_embeddings
new_vocab_size = len(tokenizer._tokenizer)
if old_vocab_size != new_vocab_size:
if new_vocab_size < old_vocab_size:
print(
"Warning: New vocab size is smaller than original. Proceeding with trim."
)
# check if QuantizedEmbedding has required attributes for dequantization
try:
dequantized_weights = mx.dequantize(
old_embedding.weight,
scales=old_embedding.scales,
biases=old_embedding.biases,
group_size=old_embedding.group_size,
bits=old_embedding.bits,
)
except AttributeError as e:
print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
print("Falling back to random weights for embed_tokens.")
dequantized_weights = mx.random.normal(
(old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
)
# resize embed_tokens
new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
min_vocab_size = min(old_vocab_size, new_vocab_size)
new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size]
if new_vocab_size > old_vocab_size:
new_weights[old_vocab_size:] = mx.random.normal(
(new_vocab_size - old_vocab_size, old_embedding.dims),
loc=0.0,
scale=0.02,
)
new_embedding.weight = new_weights
model.model.embed_tokens = new_embedding
# output head handling: tied embeddings vs. a separate lm_head
if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
model.model.embed_tokens.weight = new_weights
elif hasattr(model, "lm_head"):
old_lm_head = model.lm_head
if isinstance(old_lm_head, nn.QuantizedLinear):
# resize nn.QuantizedLinear
output_dims, compressed_input_dims = old_lm_head.weight.shape
bits = old_lm_head.bits
input_dims = compressed_input_dims * (32 // bits)
# dequantize lm_head weights
try:
dequantized_lm_weights = mx.dequantize(
old_lm_head.weight,
scales=old_lm_head.scales,
biases=old_lm_head.biases,
group_size=old_lm_head.group_size,
bits=old_lm_head.bits,
)
except AttributeError as e:
print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
print("Falling back to random weights for lm_head.")
dequantized_lm_weights = mx.random.normal(
(output_dims, input_dims), loc=0.0, scale=0.02
)
new_lm_head = nn.QuantizedLinear(
input_dims=input_dims,
output_dims=new_vocab_size,
bias="bias" in old_lm_head,
group_size=old_lm_head.group_size,
bits=old_lm_head.bits,
)
new_weights_lm = mx.zeros((new_vocab_size, input_dims))
new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
:min_vocab_size
]
if new_vocab_size > output_dims:
new_weights_lm[output_dims:] = mx.random.normal(
(new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
)
new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
mx.quantize(
new_weights_lm, new_lm_head.group_size, new_lm_head.bits
)
)
if "bias" in old_lm_head:
new_lm_head.bias = mx.zeros((new_vocab_size,))
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
:min_vocab_size
]
else:
# resize nn.Linear (nn.Linear stores no input_dims attribute, so derive it from the weight shape)
new_lm_head = nn.Linear(
old_lm_head.weight.shape[1], new_vocab_size, bias="bias" in old_lm_head
)
new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.weight.shape[1]))
min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
if new_vocab_size > old_lm_head.weight.shape[0]:
new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
(
new_vocab_size - old_lm_head.weight.shape[0],
old_lm_head.weight.shape[1],
),
loc=0.0,
scale=0.02,
)
new_lm_head.weight = new_weights_lm
# todo typechecking
if "bias" in old_lm_head:
new_lm_head.bias = mx.zeros((new_vocab_size,))
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
:min_vocab_size
]
model.lm_head = new_lm_head
else:
print("Vocab already sized right.")
return model
def update_tokenizer(
tokenizer: TokenizerWrapper, tokens: list[str], special: bool
) -> TokenizerWrapper:
"""
Appends new tokens to the end of the tokenizer vocab
"""
if special:
# TODO: go through a TokenizerWrapper accessor instead of the private _tokenizer
tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
print(f"Tokenizer updated with special tokens: {tokens}")
print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}")
else:
tokenizer._tokenizer.add_tokens(tokens)
print(f"Tokenizer updated with regular tokens: {tokens}")
return tokenizer
def implement_new_tokens(
model: nn.Module,
tokenizer: TokenizerWrapper,
tokens: list[str],
special: bool = False,
) -> tuple[nn.Module, TokenizerWrapper]:
"""
Update the model's tokenizer and embeddings with the new tokens.
"""
tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special)
model = resize_embeddings(model=model, tokenizer=tokenizer)
return model, tokenizer
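A minimal usage sketch for this module, assuming a model loaded the same way `lora.py` does (`load` is the loader already used in `run()`; the model path below is a placeholder):

```python
from mlx_lm import load
from mlx_lm.tuner.new_tokens import implement_new_tokens

# Placeholder model path; any mlx-lm compatible checkpoint would do here.
model, tokenizer = load("path/to/model")

# Register the CoT markers as special tokens and resize the embeddings/lm_head to match.
model, tokenizer = implement_new_tokens(
    model=model,
    tokenizer=tokenizer,
    tokens=["[REASONING]", "[DATA]"],
    special=True,
)
```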

View File

@@ -1,20 +1,20 @@
# Copyright © 2024 Apple Inc.
from functools import partial
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from .datasets import CompletionsDataset
from mlx_lm.tokenizer_utils import TokenizerWrapper
def grad_checkpoint(layer):
@@ -64,32 +64,80 @@ class TrainingArgs:
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)
cot: bool = field(
default=False,
metadata={"help": "Use CoT loss masking with positioning penalty"},
)
def default_loss(model, batch, lengths):
inputs = batch[:, :-1]
targets = batch[:, 1:]
def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = mask.sum()
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def iterate_batches(
dataset,
tokenizer,
batch_size,
max_seq_length,
train=False,
):
@dataclass
class CotTrainingArgs:
cot: bool = False
reasoning_token: str = "[REASONING]"
data_token: str = "[DATA]"
def cot_loss(
model: nn.Module,
inputs: mx.array,
targets: mx.array,
lengths: mx.array,
tokenizer: TokenizerWrapper,
penalty: float = 10.0,
) -> tuple[mx.array, mx.array]:
logits = model(inputs).astype(mx.float32)
reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
data_positions = mx.argmax(targets == data_token_id, axis=1)
seq_indices = mx.arange(targets.shape[1])[None, :]
# base CoT mask: starts at [DATA]
cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32)
# length mask: limits to non-padded regions
length_mask = (seq_indices < lengths[:, None]).astype(mx.float32)
# combine masks: only include tokens after [DATA] AND within sequence length
loss_mask = cot_mask * length_mask
# validate sequence structure
valid_seq = (
(reasoning_positions < data_positions)
& mx.any(targets == reasoning_token_id, axis=1)
& mx.any(targets == data_token_id, axis=1)
)
# compute base cross-entropy loss
ce = nn.losses.cross_entropy(logits, targets)
# masking loss before [DATA]; applying penalty for invalid seq
valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8)
final_loss = mx.where(valid_seq, valid_loss, penalty)  # flat penalty for structurally invalid sequences
loss = mx.mean(final_loss)
valid_tokens = mx.sum(loss_mask) + 1e-8
return loss, valid_tokens
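As an aside (not part of this diff), the mask construction above can be checked on a toy batch; the ids 5 and 7 stand in for the `[REASONING]` and `[DATA]` token ids that `cot_loss` looks up through the tokenizer:

```python
import mlx.core as mx

# Toy targets: 5 = [REASONING], 7 = [DATA], trailing zeros are padding.
targets = mx.array([[5, 1, 2, 7, 3, 4, 0, 0]])
lengths = mx.array([6])  # non-padded length of each row

data_positions = mx.argmax(targets == 7, axis=1)          # -> [3]
seq_indices = mx.arange(targets.shape[1])[None, :]
cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32)
length_mask = (seq_indices < lengths[:, None]).astype(mx.float32)

print(cot_mask * length_mask)  # [[0, 0, 0, 1, 1, 1, 0, 0]]: loss is taken from [DATA] onward, inside the sequence
```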
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
@@ -114,10 +162,6 @@ def iterate_batches(
indices = np.random.permutation(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
@@ -140,7 +184,8 @@ def iterate_batches(
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
yield batch, mx.array(list(zip(offsets, lengths)))
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
if not train:
break
@@ -156,8 +201,8 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = mx.array(0.0)
ntokens = mx.array(0)
all_losses = 0
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -213,6 +258,11 @@ def train(
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
if args.cot:
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
else:
loss = default_loss
state = [model.state, optimizer.state]
def step(batch):
@@ -233,8 +283,8 @@ def train(
n_tokens = 0
steps = 0
trained_tokens = 0
train_time = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_batches(
@@ -245,11 +295,10 @@ def train(
train=True,
),
):
tic = time.perf_counter()
# Report validation loss if needed; the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
tic = time.perf_counter()
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
@@ -260,7 +309,7 @@ def train(
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - tic
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
@@ -277,23 +326,24 @@ def train(
}
training_callback.on_val_loss_report(val_info)
tic = time.perf_counter()
start = time.perf_counter()
lvalue, toks = step(batch)
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)
train_time += time.perf_counter() - tic
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / train_time
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
@@ -322,7 +372,7 @@ def train(
losses = 0
n_tokens = 0
steps = 0
train_time = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0: