update upstream

paNikitin 2025-02-23 12:48:25 +03:00
parent 0f790c4c84
commit a2b61afd05
2 changed files with 40 additions and 16 deletions

View File

@ -79,6 +79,7 @@ def build_parser():
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
@ -94,6 +95,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)
parser.add_argument(
"--num-layers",
type=int,
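The new --mask-prompt flag pairs with the loss and batching changes in the trainer file below: when it is set, each training example presumably carries its prompt length so that only completion tokens contribute to the loss. A rough sketch of that pairing; encode_example and its arguments are illustrative assumptions, not this repository's dataset code:

    # Sketch: emit (tokens, prompt_length) pairs when prompt masking is on.
    def encode_example(tokenizer, prompt, completion, mask_prompt):
        prompt_ids = tokenizer.encode(prompt)
        completion_ids = tokenizer.encode(completion)
        tokens = prompt_ids + completion_ids
        if mask_prompt:
            # The offset tells the loss how many leading tokens to ignore.
            return tokens, len(prompt_ids)
        return tokens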
@ -136,6 +143,7 @@ def build_parser():
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
@ -157,6 +165,7 @@ def build_parser():
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
@ -176,6 +185,11 @@ def train_model(
training_callback: TrainingCallback = None,
):
model.freeze()
if args.num_layers > len(model.layers):
raise ValueError(
f"Requested to train {args.num_layers} layers "
f"but the model only has {len(model.layers)} layers."
)
if args.fine_tune_type == "full":
for l in model.layers[-max(args.num_layers, 0) :]:
l.unfreeze()
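The added check makes train_model fail fast when --num-layers asks for more layers than the model has, instead of silently unfreezing whatever is available. A minimal sketch of the guard and the slice that follows it, using a stand-in 32-layer model:

    # Stand-in model with 32 "layers"; only the guard and slice are shown.
    class TinyModel:
        layers = list(range(32))

    def layers_to_train(model, num_layers):
        if num_layers > len(model.layers):
            raise ValueError(
                f"Requested to train {num_layers} layers "
                f"but the model only has {len(model.layers)} layers."
            )
        # max(num_layers, 0) makes -1 select every layer: [-0:] is the full list.
        return model.layers[-max(num_layers, 0):]

    print(len(layers_to_train(TinyModel(), 16)))  # 16
    print(len(layers_to_train(TinyModel(), -1)))  # 32
    layers_to_train(TinyModel(), 48)              # raises ValueError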

View File

@ -6,6 +6,7 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Union
import mlx.core as mx
@ -13,7 +14,7 @@ import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from mlx_lm.tokenizer_utils import TokenizerWrapper
@ -70,14 +71,18 @@ class TrainingArgs:
)
def default_loss(model, inputs, targets, lengths):
def default_loss(model, batch, lengths):
inputs = batch[:, :-1]
targets = batch[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
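default_loss now takes the whole batch plus a lengths array whose rows are (offset, length) pairs: inputs and targets are derived by shifting inside the function, and the mask keeps only target positions at or beyond the prompt offset and within the valid sequence length, so masked prompt tokens no longer contribute to the cross-entropy. A toy illustration of the masking arithmetic with made-up token ids (values are illustrative only):

    import mlx.core as mx

    # One sequence of 6 tokens where the first 3 are prompt: offset=3, length=6.
    batch = mx.array([[10, 11, 12, 20, 21, 22]])
    lengths = mx.array([[3, 6]])

    inputs = batch[:, :-1]                      # positions 0..4
    targets = batch[:, 1:]                      # positions 1..5

    steps = mx.arange(1, targets.shape[1] + 1)  # 1-based index of each target
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
    print(mask)  # only the last three targets (the completion tokens) are kept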
@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
indices = np.random.permutation(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
)
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
yield batch, mx.array(list(zip(offsets, lengths)))
if not train:
break
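iterate_batches now accepts dataset items that are either plain token lists or (tokens, offset) pairs; offsets default to 0 when absent, and instead of pre-shifted inputs and targets the generator yields the full batch together with per-example (offset, length) pairs, matching the new default_loss signature. A small sketch with toy items (not the project's actual dataset classes):

    # Two item shapes the batching loop distinguishes:
    masked_item = ([101, 7, 8, 9, 102], 2)   # (token ids, prompt length)
    batch = [masked_item, masked_item]

    if len(batch[0]) == 2:                   # items carry an offset
        batch, offsets = zip(*batch)
    else:                                    # plain token lists
        offsets = [0] * len(batch)
    lengths = [len(x) for x in batch]
    print(list(zip(offsets, lengths)))       # [(2, 5), (2, 5)] -> fed to the loss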
@ -201,8 +210,8 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = 0
ntokens = 0
all_losses = mx.array(0.0)
ntokens = mx.array(0)
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
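Initialising all_losses and ntokens as mx.array values rather than Python ints presumably keeps the accumulators as MLX arrays from the start, so the later mx.distributed.all_sum and .item() calls always receive arrays even if no batch is processed. A tiny sketch of the accumulation:

    import mlx.core as mx

    all_losses = mx.array(0.0)   # summed loss * token count
    ntokens = mx.array(0)        # token counter

    all_losses = all_losses + mx.array(2.5) * 10
    ntokens = ntokens + 10
    print((all_losses / ntokens).item())  # 2.5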
@ -283,6 +292,7 @@ def train(
n_tokens = 0
steps = 0
trained_tokens = 0
train_time = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
@ -295,10 +305,11 @@ def train(
train=True,
),
):
tic = time.perf_counter()
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
tic = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
@ -309,7 +320,7 @@ def train(
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
val_time = time.perf_counter() - tic
if rank == 0:
print(
f"Iter {it}: "
@ -326,24 +337,23 @@ def train(
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
tic = time.perf_counter()
lvalue, toks = step(batch)
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)
train_time += time.perf_counter() - tic
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / train_time
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
@ -372,7 +382,7 @@ def train(
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()
train_time = 0
# Save adapter weights
if it % args.steps_per_save == 0:
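The remaining hunks change how training throughput is measured: train_time starts at 0, a tic is taken just before each step and the elapsed time is added after mx.eval, and the reported it/sec and tokens/sec are computed from train_time rather than the wall-clock span between reports, so validation time no longer deflates the numbers; the accumulator is reset together with losses, n_tokens and steps after each report. A compact sketch of that accumulate-and-reset pattern with stand-in helpers (do_step and run_validation are placeholders, not functions from this file):

    import time

    def do_step():            # stand-in for the forward/backward/update step
        time.sleep(0.01)

    def run_validation():     # stand-in for evaluate(); excluded from timing
        time.sleep(0.05)

    iters, steps_per_report, steps_per_eval = 6, 3, 4
    train_time = 0.0
    for it in range(1, iters + 1):
        if it % steps_per_eval == 0:
            run_validation()                        # not counted

        tic = time.perf_counter()
        do_step()
        train_time += time.perf_counter() - tic     # pure step time only

        if it % steps_per_report == 0:
            print(f"{steps_per_report / train_time:.1f} it/sec")
            train_time = 0.0                        # reset with the counters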