mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
added cot loss masking training
This commit is contained in:
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import mlx.nn as nn
|
||||
import mlx.core as mx
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
|
||||
def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
|
||||
"""
|
||||
Resizes model embeddings to accommodate new tokens
|
||||
"""
|
||||
old_embedding = model.model.embed_tokens
|
||||
|
||||
old_vocab_size = old_embedding.num_embeddings
|
||||
new_vocab_size = len(tokenizer._tokenizer)
|
||||
|
||||
if old_vocab_size != new_vocab_size:
|
||||
if new_vocab_size < old_vocab_size:
|
||||
print(
|
||||
"Warning: New vocab size is smaller than original. Proceeding with trim."
|
||||
)
|
||||
|
||||
# check if QuantizedEmbedding has required attributes for dequantization
|
||||
try:
|
||||
dequantized_weights = mx.dequantize(
|
||||
old_embedding.weight,
|
||||
scales=old_embedding.scales,
|
||||
biases=old_embedding.biases,
|
||||
group_size=old_embedding.group_size,
|
||||
bits=old_embedding.bits,
|
||||
)
|
||||
except AttributeError as e:
|
||||
print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
|
||||
print("Falling back to random weights for embed_tokens.")
|
||||
dequantized_weights = mx.random.normal(
|
||||
(old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
|
||||
)
|
||||
|
||||
# resize embed_tokens
|
||||
new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
|
||||
new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
|
||||
min_vocab_size = min(old_vocab_size, new_vocab_size)
|
||||
new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size]
|
||||
if new_vocab_size > old_vocab_size:
|
||||
new_weights[old_vocab_size:] = mx.random.normal(
|
||||
(new_vocab_size - old_vocab_size, old_embedding.dims),
|
||||
loc=0.0,
|
||||
scale=0.02,
|
||||
)
|
||||
new_embedding.weight = new_weights
|
||||
model.model.embed_tokens = new_embedding
|
||||
|
||||
# attention layers handling
|
||||
if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
|
||||
model.model.embed_tokens.weight = new_weights
|
||||
elif hasattr(model, "lm_head"):
|
||||
old_lm_head = model.lm_head
|
||||
if isinstance(old_lm_head, nn.QuantizedLinear):
|
||||
# resize nn.QuantizedLinear
|
||||
output_dims, compressed_input_dims = old_lm_head.weight.shape
|
||||
bits = old_lm_head.bits
|
||||
input_dims = compressed_input_dims * (32 // bits)
|
||||
|
||||
# dequantize lm_head weights
|
||||
try:
|
||||
dequantized_lm_weights = mx.dequantize(
|
||||
old_lm_head.weight,
|
||||
scales=old_lm_head.scales,
|
||||
biases=old_lm_head.biases,
|
||||
group_size=old_lm_head.group_size,
|
||||
bits=old_lm_head.bits,
|
||||
)
|
||||
except AttributeError as e:
|
||||
print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
|
||||
print("Falling back to random weights for lm_head.")
|
||||
dequantized_lm_weights = mx.random.normal(
|
||||
(output_dims, input_dims), loc=0.0, scale=0.02
|
||||
)
|
||||
|
||||
new_lm_head = nn.QuantizedLinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=new_vocab_size,
|
||||
bias="bias" in old_lm_head,
|
||||
group_size=old_lm_head.group_size,
|
||||
bits=old_lm_head.bits,
|
||||
)
|
||||
new_weights_lm = mx.zeros((new_vocab_size, input_dims))
|
||||
new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
|
||||
:min_vocab_size
|
||||
]
|
||||
if new_vocab_size > output_dims:
|
||||
new_weights_lm[output_dims:] = mx.random.normal(
|
||||
(new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
|
||||
)
|
||||
new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
|
||||
mx.quantize(
|
||||
new_weights_lm, new_lm_head.group_size, new_lm_head.bits
|
||||
)
|
||||
)
|
||||
if "bias" in old_lm_head:
|
||||
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||
:min_vocab_size
|
||||
]
|
||||
else:
|
||||
# resize nn.Linear
|
||||
new_lm_head = nn.Linear(
|
||||
old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
|
||||
)
|
||||
new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
|
||||
min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
|
||||
new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
|
||||
if new_vocab_size > old_lm_head.weight.shape[0]:
|
||||
new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
|
||||
(
|
||||
new_vocab_size - old_lm_head.weight.shape[0],
|
||||
old_lm_head.input_dims,
|
||||
),
|
||||
loc=0.0,
|
||||
scale=0.02,
|
||||
)
|
||||
new_lm_head.weight = new_weights_lm
|
||||
# todo typechecking
|
||||
if "bias" in old_lm_head:
|
||||
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||
:min_vocab_size
|
||||
]
|
||||
|
||||
model.lm_head = new_lm_head
|
||||
else:
|
||||
print("Vocab already sized right.")
|
||||
return model
|
||||
|
||||
|
||||
def update_tokenizer(
|
||||
tokenizer: TokenizerWrapper, tokens: list[str], special: bool
|
||||
) -> TokenizerWrapper:
|
||||
"""
|
||||
Appends new tokens to the end of the tokenizer vocab
|
||||
"""
|
||||
if special:
|
||||
# todo TokenizerWrapper access method
|
||||
tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
|
||||
print(f"Tokenizer updated with special tokens: {tokens}")
|
||||
print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}")
|
||||
else:
|
||||
# todo add regular tokens
|
||||
pass
|
||||
return tokenizer
|
||||
|
||||
|
||||
def implement_new_tokens(
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokens: list[str],
|
||||
special: bool = False,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
"""
|
||||
Update model`s tokenizer and embeddings with new tokens accordingly
|
||||
"""
|
||||
tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special)
|
||||
model = resize_embeddings(model=model, tokenizer=tokenizer)
|
||||
return model, tokenizer
|
@@ -1,20 +1,20 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
import glob
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
@@ -64,32 +64,80 @@ class TrainingArgs:
|
||||
default=False,
|
||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
||||
)
|
||||
cot: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use CoT loss masking with positioning penalty"},
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, batch, lengths):
|
||||
inputs = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
steps = mx.arange(1, targets.shape[1] + 1)
|
||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||
ntoks = mask.sum()
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
max_seq_length,
|
||||
train=False,
|
||||
):
|
||||
@dataclass
|
||||
class CotTrainingArgs:
|
||||
cot: bool = False
|
||||
reasoning_token: str = "[REASONING]"
|
||||
data_token: str = "[DATA]"
|
||||
|
||||
|
||||
def cot_loss(
|
||||
model: nn.Module,
|
||||
inputs: mx.array,
|
||||
targets: mx.array,
|
||||
lengths: int,
|
||||
tokenizer: TokenizerWrapper,
|
||||
penalty: mx.float32 = 10.0,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
logits = model(inputs).astype(mx.float32)
|
||||
|
||||
reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
|
||||
data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
|
||||
|
||||
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
|
||||
data_positions = mx.argmax(targets == data_token_id, axis=1)
|
||||
|
||||
seq_indices = mx.arange(targets.shape[1])[None, :]
|
||||
|
||||
# base CoT mask: starts at [DATA]
|
||||
cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32)
|
||||
|
||||
# length mask: limits to non-padded regions
|
||||
length_mask = (seq_indices < lengths[:, None]).astype(mx.float32)
|
||||
|
||||
# combine masks: only include tokens after [DATA] AND within sequence length
|
||||
loss_mask = cot_mask * length_mask
|
||||
|
||||
# validate sequence structure
|
||||
valid_seq = (
|
||||
(reasoning_positions < data_positions)
|
||||
& mx.any(targets == reasoning_token_id, axis=1)
|
||||
& mx.any(targets == data_token_id, axis=1)
|
||||
)
|
||||
|
||||
# compute base cross-entropy loss
|
||||
ce = nn.losses.cross_entropy(logits, targets)
|
||||
|
||||
# masking loss before [DATA]; applying penalty for invalid seq
|
||||
valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8)
|
||||
final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty
|
||||
loss = mx.mean(final_loss)
|
||||
|
||||
valid_tokens = mx.sum(loss_mask) + 1e-8
|
||||
|
||||
return loss, valid_tokens
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
@@ -114,10 +162,6 @@ def iterate_batches(
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
if len(batch[0]) == 2:
|
||||
batch, offsets = zip(*batch)
|
||||
else:
|
||||
offsets = [0] * len(batch)
|
||||
lengths = [len(x) for x in batch]
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
@@ -140,7 +184,8 @@ def iterate_batches(
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
batch = mx.array(batch_arr)
|
||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
|
||||
if not train:
|
||||
break
|
||||
@@ -156,8 +201,8 @@ def evaluate(
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = mx.array(0.0)
|
||||
ntokens = mx.array(0)
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
@@ -213,6 +258,11 @@ def train(
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
|
||||
if args.cot:
|
||||
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
|
||||
else:
|
||||
loss = default_loss
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
@@ -233,8 +283,8 @@ def train(
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
train_time = 0
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
range(1, args.iters + 1),
|
||||
iterate_batches(
|
||||
@@ -245,11 +295,10 @@ def train(
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
tic = time.perf_counter()
|
||||
# Report validation loss if needed, the first validation loss
|
||||
# is always measured before any training.
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
tic = time.perf_counter()
|
||||
stop = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
@@ -260,7 +309,7 @@ def train(
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - tic
|
||||
val_time = time.perf_counter() - stop
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: "
|
||||
@@ -277,23 +326,24 @@ def train(
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
|
||||
tic = time.perf_counter()
|
||||
start = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
losses += lvalue
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, n_tokens)
|
||||
train_time += time.perf_counter() - tic
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / train_time
|
||||
tokens_sec = float(n_tokens) / train_time
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
if rank == 0:
|
||||
@@ -322,7 +372,7 @@ def train(
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
train_time = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
if it % args.steps_per_save == 0:
|
||||
|
Reference in New Issue
Block a user