mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 09:56:24 +08:00
added cot loss masking training
This commit is contained in:
parent
09b641aaa7
commit
68403f5577
@ -64,6 +64,12 @@ lora_parameters:
|
|||||||
scale: 20.0
|
scale: 20.0
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
|
|
||||||
|
# cot loss masking training
|
||||||
|
# cot:
|
||||||
|
# use_cot: true
|
||||||
|
# special: true
|
||||||
|
# additional_tokens: ["[REASONING]", "[DATA]"]
|
||||||
|
|
||||||
# Schedule can only be specified in a config file, uncomment to use.
|
# Schedule can only be specified in a config file, uncomment to use.
|
||||||
#lr_schedule:
|
#lr_schedule:
|
||||||
# name: cosine_decay
|
# name: cosine_decay
|
||||||
|
@ -62,6 +62,7 @@ CONFIG_DEFAULTS = {
|
|||||||
"grad_checkpoint": False,
|
"grad_checkpoint": False,
|
||||||
"lr_schedule": None,
|
"lr_schedule": None,
|
||||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||||
|
"cot": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -78,7 +79,6 @@ def build_parser():
|
|||||||
"--train",
|
"--train",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Do training",
|
help="Do training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
@ -94,14 +94,6 @@ def build_parser():
|
|||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--mask-prompt",
|
|
||||||
action="store_true",
|
|
||||||
help="Mask the prompt in the loss when training",
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-layers",
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
@ -144,7 +136,6 @@ def build_parser():
|
|||||||
"--test",
|
"--test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Evaluate on the test set after training",
|
help="Evaluate on the test set after training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-batches",
|
"--test-batches",
|
||||||
@ -166,9 +157,13 @@ def build_parser():
|
|||||||
"--grad-checkpoint",
|
"--grad-checkpoint",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use gradient checkpointing to reduce memory use.",
|
help="Use gradient checkpointing to reduce memory use.",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--cot",
|
||||||
|
type=bool,
|
||||||
|
help="Use CoT loss masking",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -181,14 +176,8 @@ def train_model(
|
|||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
model.freeze()
|
model.freeze()
|
||||||
if args.num_layers > len(model.layers):
|
|
||||||
raise ValueError(
|
|
||||||
f"Requested to train {args.num_layers} layers "
|
|
||||||
f"but the model only has {len(model.layers)} layers."
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.fine_tune_type == "full":
|
if args.fine_tune_type == "full":
|
||||||
for l in model.layers[-max(args.num_layers, 0) :]:
|
for l in model.layers[-min(args.num_layers, 0) :]:
|
||||||
l.unfreeze()
|
l.unfreeze()
|
||||||
elif args.fine_tune_type in ["lora", "dora"]:
|
elif args.fine_tune_type in ["lora", "dora"]:
|
||||||
# Convert linear layers to lora/dora layers and unfreeze in the process
|
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||||
@ -225,10 +214,13 @@ def train_model(
|
|||||||
adapter_file=adapter_file,
|
adapter_file=adapter_file,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
grad_checkpoint=args.grad_checkpoint,
|
grad_checkpoint=args.grad_checkpoint,
|
||||||
|
cot=(cot := args.cot),
|
||||||
)
|
)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
opt = optim.Adam(
|
# todo optimizer from args
|
||||||
|
|
||||||
|
opt = optim.AdamW(
|
||||||
learning_rate=(
|
learning_rate=(
|
||||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||||
)
|
)
|
||||||
@ -269,6 +261,21 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
print("Loading pretrained model")
|
print("Loading pretrained model")
|
||||||
model, tokenizer = load(args.model)
|
model, tokenizer = load(args.model)
|
||||||
|
|
||||||
|
if cot := args.cot:
|
||||||
|
print("Using CoT loss masking")
|
||||||
|
if tokens := cot.get("additional_tokens"):
|
||||||
|
from .tuner.new_tokens import implement_new_tokens
|
||||||
|
|
||||||
|
special = False
|
||||||
|
if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
|
||||||
|
print("Updating model and tokenizer with new special tokens")
|
||||||
|
special = special_arg
|
||||||
|
else:
|
||||||
|
print("Updating model and tokenizer with new tokens")
|
||||||
|
model, tokenizer = implement_new_tokens(
|
||||||
|
model=model, tokenizer=tokenizer, tokens=tokens, special=special
|
||||||
|
)
|
||||||
|
|
||||||
print("Loading datasets")
|
print("Loading datasets")
|
||||||
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||||
|
|
||||||
@ -293,6 +300,7 @@ def main():
|
|||||||
parser = build_parser()
|
parser = build_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
config = args.config
|
config = args.config
|
||||||
|
|
||||||
args = vars(args)
|
args = vars(args)
|
||||||
if config:
|
if config:
|
||||||
print("Loading configuration file", config)
|
print("Loading configuration file", config)
|
||||||
|
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
162
llms/mlx_lm/tuner/new_tokens.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Resizes model embeddings to accommodate new tokens
|
||||||
|
"""
|
||||||
|
old_embedding = model.model.embed_tokens
|
||||||
|
|
||||||
|
old_vocab_size = old_embedding.num_embeddings
|
||||||
|
new_vocab_size = len(tokenizer._tokenizer)
|
||||||
|
|
||||||
|
if old_vocab_size != new_vocab_size:
|
||||||
|
if new_vocab_size < old_vocab_size:
|
||||||
|
print(
|
||||||
|
"Warning: New vocab size is smaller than original. Proceeding with trim."
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if QuantizedEmbedding has required attributes for dequantization
|
||||||
|
try:
|
||||||
|
dequantized_weights = mx.dequantize(
|
||||||
|
old_embedding.weight,
|
||||||
|
scales=old_embedding.scales,
|
||||||
|
biases=old_embedding.biases,
|
||||||
|
group_size=old_embedding.group_size,
|
||||||
|
bits=old_embedding.bits,
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
|
||||||
|
print("Falling back to random weights for embed_tokens.")
|
||||||
|
dequantized_weights = mx.random.normal(
|
||||||
|
(old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
|
||||||
|
)
|
||||||
|
|
||||||
|
# resize embed_tokens
|
||||||
|
new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
|
||||||
|
new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
|
||||||
|
min_vocab_size = min(old_vocab_size, new_vocab_size)
|
||||||
|
new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size]
|
||||||
|
if new_vocab_size > old_vocab_size:
|
||||||
|
new_weights[old_vocab_size:] = mx.random.normal(
|
||||||
|
(new_vocab_size - old_vocab_size, old_embedding.dims),
|
||||||
|
loc=0.0,
|
||||||
|
scale=0.02,
|
||||||
|
)
|
||||||
|
new_embedding.weight = new_weights
|
||||||
|
model.model.embed_tokens = new_embedding
|
||||||
|
|
||||||
|
# attention layers handling
|
||||||
|
if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
|
||||||
|
model.model.embed_tokens.weight = new_weights
|
||||||
|
elif hasattr(model, "lm_head"):
|
||||||
|
old_lm_head = model.lm_head
|
||||||
|
if isinstance(old_lm_head, nn.QuantizedLinear):
|
||||||
|
# resize nn.QuantizedLinear
|
||||||
|
output_dims, compressed_input_dims = old_lm_head.weight.shape
|
||||||
|
bits = old_lm_head.bits
|
||||||
|
input_dims = compressed_input_dims * (32 // bits)
|
||||||
|
|
||||||
|
# dequantize lm_head weights
|
||||||
|
try:
|
||||||
|
dequantized_lm_weights = mx.dequantize(
|
||||||
|
old_lm_head.weight,
|
||||||
|
scales=old_lm_head.scales,
|
||||||
|
biases=old_lm_head.biases,
|
||||||
|
group_size=old_lm_head.group_size,
|
||||||
|
bits=old_lm_head.bits,
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
|
||||||
|
print("Falling back to random weights for lm_head.")
|
||||||
|
dequantized_lm_weights = mx.random.normal(
|
||||||
|
(output_dims, input_dims), loc=0.0, scale=0.02
|
||||||
|
)
|
||||||
|
|
||||||
|
new_lm_head = nn.QuantizedLinear(
|
||||||
|
input_dims=input_dims,
|
||||||
|
output_dims=new_vocab_size,
|
||||||
|
bias="bias" in old_lm_head,
|
||||||
|
group_size=old_lm_head.group_size,
|
||||||
|
bits=old_lm_head.bits,
|
||||||
|
)
|
||||||
|
new_weights_lm = mx.zeros((new_vocab_size, input_dims))
|
||||||
|
new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
|
||||||
|
:min_vocab_size
|
||||||
|
]
|
||||||
|
if new_vocab_size > output_dims:
|
||||||
|
new_weights_lm[output_dims:] = mx.random.normal(
|
||||||
|
(new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
|
||||||
|
)
|
||||||
|
new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
|
||||||
|
mx.quantize(
|
||||||
|
new_weights_lm, new_lm_head.group_size, new_lm_head.bits
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "bias" in old_lm_head:
|
||||||
|
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||||
|
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||||
|
:min_vocab_size
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# resize nn.Linear
|
||||||
|
new_lm_head = nn.Linear(
|
||||||
|
old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
|
||||||
|
)
|
||||||
|
new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
|
||||||
|
min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
|
||||||
|
new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
|
||||||
|
if new_vocab_size > old_lm_head.weight.shape[0]:
|
||||||
|
new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
|
||||||
|
(
|
||||||
|
new_vocab_size - old_lm_head.weight.shape[0],
|
||||||
|
old_lm_head.input_dims,
|
||||||
|
),
|
||||||
|
loc=0.0,
|
||||||
|
scale=0.02,
|
||||||
|
)
|
||||||
|
new_lm_head.weight = new_weights_lm
|
||||||
|
# todo typechecking
|
||||||
|
if "bias" in old_lm_head:
|
||||||
|
new_lm_head.bias = mx.zeros((new_vocab_size,))
|
||||||
|
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
|
||||||
|
:min_vocab_size
|
||||||
|
]
|
||||||
|
|
||||||
|
model.lm_head = new_lm_head
|
||||||
|
else:
|
||||||
|
print("Vocab already sized right.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def update_tokenizer(
|
||||||
|
tokenizer: TokenizerWrapper, tokens: list[str], special: bool
|
||||||
|
) -> TokenizerWrapper:
|
||||||
|
"""
|
||||||
|
Appends new tokens to the end of the tokenizer vocab
|
||||||
|
"""
|
||||||
|
if special:
|
||||||
|
# todo TokenizerWrapper access method
|
||||||
|
tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
|
||||||
|
print(f"Tokenizer updated with special tokens: {tokens}")
|
||||||
|
print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}")
|
||||||
|
else:
|
||||||
|
# todo add regular tokens
|
||||||
|
pass
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def implement_new_tokens(
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: TokenizerWrapper,
|
||||||
|
tokens: list[str],
|
||||||
|
special: bool = False,
|
||||||
|
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||||
|
"""
|
||||||
|
Update model`s tokenizer and embeddings with new tokens accordingly
|
||||||
|
"""
|
||||||
|
tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special)
|
||||||
|
model = resize_embeddings(model=model, tokenizer=tokenizer)
|
||||||
|
return model, tokenizer
|
@ -1,20 +1,20 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from transformers import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from .datasets import CompletionsDataset
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
|
||||||
|
|
||||||
def grad_checkpoint(layer):
|
def grad_checkpoint(layer):
|
||||||
@ -64,32 +64,80 @@ class TrainingArgs:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
||||||
)
|
)
|
||||||
|
cot: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use CoT loss masking with positioning penalty"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_loss(model, batch, lengths):
|
def default_loss(model, inputs, targets, lengths):
|
||||||
inputs = batch[:, :-1]
|
|
||||||
targets = batch[:, 1:]
|
|
||||||
|
|
||||||
logits = model(inputs)
|
logits = model(inputs)
|
||||||
logits = logits.astype(mx.float32)
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
steps = mx.arange(1, targets.shape[1] + 1)
|
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
|
||||||
|
|
||||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||||
ntoks = mask.sum()
|
ntoks = length_mask.sum()
|
||||||
ce = ce.sum() / ntoks
|
ce = ce.sum() / ntoks
|
||||||
|
|
||||||
return ce, ntoks
|
return ce, ntoks
|
||||||
|
|
||||||
|
|
||||||
def iterate_batches(
|
@dataclass
|
||||||
dataset,
|
class CotTrainingArgs:
|
||||||
tokenizer,
|
cot: bool = False
|
||||||
batch_size,
|
reasoning_token: str = "[REASONING]"
|
||||||
max_seq_length,
|
data_token: str = "[DATA]"
|
||||||
train=False,
|
|
||||||
):
|
|
||||||
|
def cot_loss(
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: mx.array,
|
||||||
|
targets: mx.array,
|
||||||
|
lengths: int,
|
||||||
|
tokenizer: TokenizerWrapper,
|
||||||
|
penalty: mx.float32 = 10.0,
|
||||||
|
) -> tuple[mx.array, mx.array]:
|
||||||
|
logits = model(inputs).astype(mx.float32)
|
||||||
|
|
||||||
|
reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0]
|
||||||
|
data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0]
|
||||||
|
|
||||||
|
reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1)
|
||||||
|
data_positions = mx.argmax(targets == data_token_id, axis=1)
|
||||||
|
|
||||||
|
seq_indices = mx.arange(targets.shape[1])[None, :]
|
||||||
|
|
||||||
|
# base CoT mask: starts at [DATA]
|
||||||
|
cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32)
|
||||||
|
|
||||||
|
# length mask: limits to non-padded regions
|
||||||
|
length_mask = (seq_indices < lengths[:, None]).astype(mx.float32)
|
||||||
|
|
||||||
|
# combine masks: only include tokens after [DATA] AND within sequence length
|
||||||
|
loss_mask = cot_mask * length_mask
|
||||||
|
|
||||||
|
# validate sequence structure
|
||||||
|
valid_seq = (
|
||||||
|
(reasoning_positions < data_positions)
|
||||||
|
& mx.any(targets == reasoning_token_id, axis=1)
|
||||||
|
& mx.any(targets == data_token_id, axis=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute base cross-entropy loss
|
||||||
|
ce = nn.losses.cross_entropy(logits, targets)
|
||||||
|
|
||||||
|
# masking loss before [DATA]; applying penalty for invalid seq
|
||||||
|
valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8)
|
||||||
|
final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty
|
||||||
|
loss = mx.mean(final_loss)
|
||||||
|
|
||||||
|
valid_tokens = mx.sum(loss_mask) + 1e-8
|
||||||
|
|
||||||
|
return loss, valid_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||||
# Sort by length:
|
# Sort by length:
|
||||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||||
if len(dataset) < batch_size:
|
if len(dataset) < batch_size:
|
||||||
@ -114,10 +162,6 @@ def iterate_batches(
|
|||||||
indices = np.random.permutation(len(batch_idx))
|
indices = np.random.permutation(len(batch_idx))
|
||||||
for i in indices:
|
for i in indices:
|
||||||
batch = [dataset[j] for j in batch_idx[i]]
|
batch = [dataset[j] for j in batch_idx[i]]
|
||||||
if len(batch[0]) == 2:
|
|
||||||
batch, offsets = zip(*batch)
|
|
||||||
else:
|
|
||||||
offsets = [0] * len(batch)
|
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
if max(lengths) > max_seq_length:
|
if max(lengths) > max_seq_length:
|
||||||
print(
|
print(
|
||||||
@ -140,7 +184,8 @@ def iterate_batches(
|
|||||||
truncated_length # Update lengths to match truncated lengths
|
truncated_length # Update lengths to match truncated lengths
|
||||||
)
|
)
|
||||||
batch = mx.array(batch_arr)
|
batch = mx.array(batch_arr)
|
||||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
|
||||||
|
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||||
|
|
||||||
if not train:
|
if not train:
|
||||||
break
|
break
|
||||||
@ -156,8 +201,8 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = mx.array(0.0)
|
all_losses = 0
|
||||||
ntokens = mx.array(0)
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
@ -213,6 +258,11 @@ def train(
|
|||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
grad_checkpoint(model.layers[0])
|
grad_checkpoint(model.layers[0])
|
||||||
|
|
||||||
|
if args.cot:
|
||||||
|
loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0)
|
||||||
|
else:
|
||||||
|
loss = default_loss
|
||||||
|
|
||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
def step(batch):
|
def step(batch):
|
||||||
@ -233,8 +283,8 @@ def train(
|
|||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
steps = 0
|
steps = 0
|
||||||
trained_tokens = 0
|
trained_tokens = 0
|
||||||
train_time = 0
|
|
||||||
# Main training loop
|
# Main training loop
|
||||||
|
start = time.perf_counter()
|
||||||
for it, batch in zip(
|
for it, batch in zip(
|
||||||
range(1, args.iters + 1),
|
range(1, args.iters + 1),
|
||||||
iterate_batches(
|
iterate_batches(
|
||||||
@ -245,11 +295,10 @@ def train(
|
|||||||
train=True,
|
train=True,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
tic = time.perf_counter()
|
|
||||||
# Report validation loss if needed, the first validation loss
|
# Report validation loss if needed, the first validation loss
|
||||||
# is always measured before any training.
|
# is always measured before any training.
|
||||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||||
tic = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
val_loss = evaluate(
|
val_loss = evaluate(
|
||||||
model=model,
|
model=model,
|
||||||
dataset=val_dataset,
|
dataset=val_dataset,
|
||||||
@ -260,7 +309,7 @@ def train(
|
|||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
iterate_batches=iterate_batches,
|
iterate_batches=iterate_batches,
|
||||||
)
|
)
|
||||||
val_time = time.perf_counter() - tic
|
val_time = time.perf_counter() - stop
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: "
|
f"Iter {it}: "
|
||||||
@ -277,23 +326,24 @@ def train(
|
|||||||
}
|
}
|
||||||
training_callback.on_val_loss_report(val_info)
|
training_callback.on_val_loss_report(val_info)
|
||||||
|
|
||||||
tic = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
lvalue, toks = step(batch)
|
lvalue, toks = step(batch)
|
||||||
losses += lvalue
|
losses += lvalue
|
||||||
n_tokens += toks
|
n_tokens += toks
|
||||||
steps += 1
|
steps += 1
|
||||||
mx.eval(state, losses, n_tokens)
|
mx.eval(state, losses, n_tokens)
|
||||||
train_time += time.perf_counter() - tic
|
|
||||||
|
|
||||||
# Report training loss if needed
|
# Report training loss if needed
|
||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||||
train_loss /= steps * mx.distributed.init().size()
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / train_time
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / train_time
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
trained_tokens += n_tokens
|
trained_tokens += n_tokens
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
@ -322,7 +372,7 @@ def train(
|
|||||||
losses = 0
|
losses = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
steps = 0
|
steps = 0
|
||||||
train_time = 0
|
start = time.perf_counter()
|
||||||
|
|
||||||
# Save adapter weights
|
# Save adapter weights
|
||||||
if it % args.steps_per_save == 0:
|
if it % args.steps_per_save == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user