update upstream

paNikitin 2025-02-23 12:48:25 +03:00
parent 0f790c4c84
commit a2b61afd05
2 changed files with 40 additions and 16 deletions

View File

@@ -79,6 +79,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -94,6 +95,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -136,6 +143,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -157,6 +165,7 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     parser.add_argument(
@@ -176,6 +185,11 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
     if args.fine_tune_type == "full":
         for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
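The CLI changes above do two things: the boolean flags (`--train`, `--test`, `--grad-checkpoint`) move from argparse's implicit `default=False` to `default=None`, presumably so that a later config-merge step can tell "flag not passed" apart from "flag explicitly off", and a new `--mask-prompt` flag is added alongside a guard in `train_model` that fails fast when `--num-layers` exceeds the model's depth. A minimal sketch of the unset-flag idea, using a hypothetical merge step that is not part of this diff:

```python
# Illustrative sketch, not from the diff: with default=None a config-merge
# step can tell "flag omitted on the CLI" apart from "flag set", so a value
# loaded from a config file is not clobbered by argparse's implicit False.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train", action="store_true", default=None)
parser.add_argument("--mask-prompt", action="store_true", default=False)

args = parser.parse_args([])      # nothing passed on the command line
config = {"train": True}          # hypothetical value from a config file

for key, value in config.items():
    if getattr(args, key, None) is None:  # only fill flags the user left unset
        setattr(args, key, value)

print(args.train, args.mask_prompt)  # True False
```

With `default=False` the merge above could never apply the config value, because `args.train` would already be `False` even when the user said nothing.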

View File

@@ -6,6 +6,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
+from typing import List, Optional, Tuple
 from typing import Union

 import mlx.core as mx
@@ -13,7 +14,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer

 from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -70,14 +71,18 @@ class TrainingArgs:
     )


-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)

-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks

     return ce, ntoks
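The loss now receives the raw `batch` plus one `(offset, length)` row per sequence instead of pre-split inputs/targets and a flat length vector: target step `t` (1-indexed) contributes to the cross-entropy only when `offset <= t <= length`, which is what lets `--mask-prompt` drop prompt tokens from the loss. A toy check of the mask logic, with made-up token values:

```python
# Toy check of the new mask, assuming `lengths` holds one (offset, length)
# row per sequence as produced by iterate_batches below.
import mlx.core as mx

batch = mx.array([[1, 2, 3, 4, 5, 6]])  # one made-up sequence of 6 tokens
lengths = mx.array([[3, 6]])            # offset 3: target steps 1-2 (the prompt) are masked

targets = batch[:, 1:]                      # 5 target positions
steps = mx.arange(1, targets.shape[1] + 1)  # [1, 2, 3, 4, 5]
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)  # [[False, False, True, True, True]] -> prompt steps drop out
```

With an offset of 0 (the default that `iterate_batches` assigns below) the mask reduces to padding-only masking, matching the old `length_mask` behaviour.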
@@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             )
             batch = mx.array(batch_arr)
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))

         if not train:
             break
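`iterate_batches` now accepts dataset items that are either plain token lists or `(tokens, offset)` pairs, and instead of yielding pre-shifted `(inputs, targets, lengths)` it yields the padded batch together with an array of `(offset, length)` rows, matching the `lengths[:, 0:1]` / `lengths[:, 1:]` indexing in `default_loss`. A small sketch of the unpacking on made-up items:

```python
# Sketch of the new unpacking, with made-up dataset items. Items may be plain
# token lists or (tokens, prompt_offset) pairs; plain lists get offset 0.
batch = [([1, 2, 3, 4], 2), ([5, 6, 7], 1)]   # two (tokens, offset) pairs

if len(batch[0]) == 2:
    batch, offsets = zip(*batch)
else:
    offsets = [0] * len(batch)

lengths = [len(x) for x in batch]
print(list(zip(offsets, lengths)))   # [(2, 4), (1, 3)]
```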
@@ -201,8 +210,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
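`evaluate` now starts its accumulators as MLX scalars rather than Python numbers, so the running totals stay as arrays throughout the loop. A hedged sketch of how such accumulators are typically used; only the two initializers come from this hunk, the loop body below is an assumption for illustration:

```python
# Hedged sketch of MLX-scalar accumulators; the per-batch (loss, ntoks) pairs
# are made up and stand in for what the loss function returns.
import mlx.core as mx

all_losses = mx.array(0.0)
ntokens = mx.array(0)

for losses, toks in [(mx.array(1.25), mx.array(10)), (mx.array(0.75), mx.array(30))]:
    all_losses += losses * toks   # token-weighted sum of per-batch losses
    ntokens += toks
    mx.eval(all_losses, ntokens)  # materialize so the graph stays small

print((all_losses / ntokens).item())  # 0.875
```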
@@ -283,6 +292,7 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -295,10 +305,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -309,7 +320,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -326,24 +337,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -372,7 +382,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
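Throughput reporting switches from `start`/`stop` deltas to an accumulated `train_time`: each iteration adds only the wall-clock time spent around `step(batch)`, so validation runs no longer leak into `it/sec` and `tokens/sec`, and the counter is reset after every report. A toy illustration of that bookkeeping (every number below is made up):

```python
# Toy illustration of the train_time bookkeeping; the sleep stands in for
# step(batch), and an evaluate() call between iterations would not be counted.
import time

train_time = 0.0
steps_per_report = 2
n_tokens = 0

for it in range(1, 5):
    tic = time.perf_counter()
    time.sleep(0.01)              # pretend this is step(batch)
    n_tokens += 128               # pretend this is toks
    train_time += time.perf_counter() - tic

    if it % steps_per_report == 0:
        it_sec = steps_per_report / train_time
        tokens_sec = n_tokens / train_time
        print(f"it/sec {it_sec:.1f}, tokens/sec {tokens_sec:.0f}")
        n_tokens = 0
        train_time = 0.0
```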