Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 04:25:06 +08:00

Commit a2b61afd05 ("update upstream"), parent 0f790c4c84
@@ -79,6 +79,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -94,6 +95,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
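Note on the new --mask-prompt flag (my reading of this diff, not text from the commit): with prompt masking enabled, each training example apparently carries the length of its prompt alongside the token ids so the loss can skip the prompt tokens. A minimal sketch of the data shape this implies, with made-up token ids:

# Sketch only: the (tokens, offset) pairs that the --mask-prompt path appears
# to rely on (compare the `if len(batch[0]) == 2` branch in iterate_batches below).
prompt_ids = [1, 15, 27, 9]        # tokenized prompt
completion_ids = [42, 8, 3, 2]     # tokenized completion
tokens = prompt_ids + completion_ids
offset = len(prompt_ids)           # leading tokens to exclude from the loss
example = (tokens, offset)         # a 2-tuple, matching the len(batch[0]) == 2 check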
@@ -136,6 +143,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -157,6 +165,7 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     parser.add_argument(
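The default=None additions to --train, --test, and --grad-checkpoint read like a config-merging setup in which an option left unspecified on the command line should not clobber a value loaded from a config file; that interpretation is an assumption on my part, not stated in the diff. A minimal sketch of that pattern with a hypothetical merge helper:

# Sketch only: assumed config-merge pattern that default=None enables.
# `config` is a hypothetical dict loaded from a YAML/JSON config file.
def merge_config(args, config):
    for key, value in config.items():
        if getattr(args, key, None) is None:  # flag not given on the CLI
            setattr(args, key, value)         # fall back to the config value
    return args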
@@ -176,6 +185,11 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
     if args.fine_tune_type == "full":
         for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
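The added guard in train_model fails fast when --num-layers exceeds the model depth; without it, the slice model.layers[-num_layers:] silently trains every available layer rather than the requested count. A tiny standalone illustration (hypothetical layer names and counts):

# Sketch only: Python slicing does not raise on an oversized negative start,
# so an over-large --num-layers previously went unnoticed.
layers = ["layer0", "layer1", "layer2"]   # pretend the model has 3 layers
num_layers = 8                            # user asks for 8
print(layers[-max(num_layers, 0):])       # ['layer0', 'layer1', 'layer2'] -- only 3, no error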
@@ -6,6 +6,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import List, Optional, Tuple
+from typing import Union
 
 import mlx.core as mx
@@ -13,7 +14,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 
-from transformers import PreTrainedTokenizer
+from mlx_lm.tokenizer_utils import TokenizerWrapper
 
 
@@ -70,14 +71,18 @@ class TrainingArgs:
     )
 
 
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
 
     return ce, ntoks
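The reworked default_loss takes the unshifted batch plus a per-row (offset, length) pair: lengths[:, 0:1] is the number of prompt tokens to skip and lengths[:, 1:] the sequence length, so the mask drops prompt predictions and padding from the cross-entropy. A small worked example evaluating the same mask expression (standalone sketch, hypothetical numbers):

# Sketch only: the mask from the new default_loss, evaluated numerically.
import mlx.core as mx

lengths = mx.array([[3, 6]])   # one row of (offset, length), as yielded by iterate_batches
T = 7                          # targets.shape[1] in this toy case
steps = mx.arange(1, T + 1)    # [1, 2, 3, 4, 5, 6, 7]
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)                    # True only for steps 3..6: prompt and padding positions drop out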
@@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                 )
             batch = mx.array(batch_arr)
 
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break
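iterate_batches now accepts dataset items that are either plain token lists or (tokens, offset) pairs, and it yields the whole batch together with one (offset, length) row per sequence instead of pre-shifted inputs and targets; the shift moved into default_loss above. A standalone sketch of the bookkeeping with made-up token ids:

# Sketch only: mirroring the new branch in iterate_batches.
batch = [([5, 6, 7, 8, 9], 2), ([1, 2, 3, 4], 1)]   # (tokens, prompt_offset) pairs
if len(batch[0]) == 2:
    batch, offsets = zip(*batch)
else:
    offsets = [0] * len(batch)
lengths = [len(x) for x in batch]
print(list(zip(offsets, lengths)))                  # [(2, 5), (1, 4)]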
@@ -201,8 +210,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
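In evaluate, the running totals start as mx.array scalars rather than Python numbers; I take this to keep them usable with array operations such as mx.eval (and, presumably, distributed reductions), though the diff does not show that code. A minimal sketch of the accumulation pattern with hypothetical per-batch values:

# Sketch only: accumulating into mx.array scalars, as evaluate now initializes them.
import mlx.core as mx

all_losses = mx.array(0.0)
ntokens = mx.array(0)
for loss_value, toks in [(mx.array(2.0), mx.array(8)), (mx.array(1.5), mx.array(6))]:
    all_losses += loss_value * toks   # token-weighted loss sum
    ntokens += toks
    mx.eval(all_losses, ntokens)      # force evaluation of the running totals
print((all_losses / ntokens).item())  # (2.0*8 + 1.5*6) / 14 ≈ 1.786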
@@ -283,6 +292,7 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -295,10 +305,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -309,7 +320,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -326,24 +337,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)
 
-            start = time.perf_counter()
+            tic = time.perf_counter()
 
         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
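The timing changes replace the start/stop wall-clock pair with a per-iteration tic and an accumulated train_time, so the reported it/sec and tokens/sec cover only the time spent in training steps and exclude validation and reporting. A condensed sketch of the bookkeeping (the step body is a hypothetical stand-in):

# Sketch only: accumulate pure training time per report window, as the new loop does.
import time

train_time = 0.0
steps_per_report = 10
for it in range(1, 101):
    tic = time.perf_counter()
    # ... one hypothetical optimizer step would run here ...
    train_time += time.perf_counter() - tic
    if it % steps_per_report == 0:
        it_sec = steps_per_report / train_time   # validation/report time never enters this rate
        print(f"Iter {it}: {it_sec:.1f} it/s")
        train_time = 0.0                         # reset for the next window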
@@ -372,7 +382,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0
 
         # Save adapter weights
         if it % args.steps_per_save == 0: