Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 06:22:46 +08:00)

Commit a2b61afd05 ("update upstream"), parent 0f790c4c84
@@ -79,6 +79,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -94,6 +95,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -136,6 +143,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -157,6 +165,7 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     parser.add_argument(
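A side note on these parser changes: boolean flags such as `--train`, `--test`, and `--grad-checkpoint` now default to `None` rather than `False`, while the new `--mask-prompt` keeps `default=False`. A minimal, self-contained sketch (not the project's actual parser) of how those defaults behave:

```python
import argparse

# Illustrative parser only; mirrors the default pattern used above.
parser = argparse.ArgumentParser()
parser.add_argument("--train", action="store_true", default=None)
parser.add_argument("--mask-prompt", action="store_true", default=False)

print(parser.parse_args([]))            # Namespace(train=None, mask_prompt=False)
print(parser.parse_args(["--train"]))   # Namespace(train=True, mask_prompt=False)
```

A `None` default lets downstream code distinguish "flag not given on the command line" from "explicitly disabled", which is the usual reason for this pattern when CLI values are merged with a config file; that motivation is an inference, not something stated in the diff.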
@@ -176,6 +185,11 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
     if args.fine_tune_type == "full":
         for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
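The new guard fails fast instead of silently training on fewer layers than requested. A toy illustration with a hypothetical 4-layer model (the class and values here are made up for the example):

```python
# Hypothetical stand-in for a model with 4 transformer layers.
class TinyModel:
    def __init__(self):
        self.layers = [None] * 4

model, num_layers = TinyModel(), 8
try:
    if num_layers > len(model.layers):
        raise ValueError(
            f"Requested to train {num_layers} layers "
            f"but the model only has {len(model.layers)} layers."
        )
except ValueError as err:
    print(err)  # Requested to train 8 layers but the model only has 4 layers.
```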
[The hunks below apply to a second file, the LoRA trainer module; the diff's per-file header was not captured in this view.]

@@ -6,6 +6,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
+from typing import List, Optional, Tuple
 from typing import Union

 import mlx.core as mx
@@ -13,7 +14,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
+from transformers import PreTrainedTokenizer
 from mlx_lm.tokenizer_utils import TokenizerWrapper


@@ -70,14 +71,18 @@ class TrainingArgs:
     )


-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)

-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks

     return ce, ntoks
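`default_loss` now takes the unshifted batch plus an `(offset, length)` pair per row and builds the loss mask from both, so prompt tokens (when the dataset supplies prompt offsets) and padding are excluded together. A small standalone sketch of that mask arithmetic, with made-up values:

```python
import mlx.core as mx

# Two toy rows of (offset, length): row 0 has no prompt masking (offset 0),
# row 1 masks its first two target positions as prompt.
lengths = mx.array([[0, 5], [3, 6]])

steps = mx.arange(1, 7)  # target positions 1..6, as in the new default_loss
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)
# Row 0: [True, True, True, True, True, False]   -> only padding is masked
# Row 1: [False, False, True, True, True, True]  -> prompt and padding masked
```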
@@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                 )
             batch = mx.array(batch_arr)

-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))

         if not train:
             break
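With this change `iterate_batches` yields the whole padded token array plus packed `(offset, length)` pairs; the input/target shift that used to happen here now lives in the loss function. A toy consumer, using invented token values, showing the shapes involved:

```python
import mlx.core as mx

# Invented stand-ins for one yielded batch: two padded rows of token ids and
# their (offset, length) pairs, mirroring mx.array(list(zip(offsets, lengths))).
batch = mx.array([[1, 2, 3, 4, 0], [5, 6, 7, 8, 9]])
lengths = mx.array(list(zip([0, 2], [4, 5])))   # offsets [0, 2], lengths [4, 5]

inputs, targets = batch[:, :-1], batch[:, 1:]   # the shift now done in default_loss
print(inputs.shape, targets.shape, lengths.shape)  # (2, 4) (2, 4) (2, 2)
```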
@@ -201,8 +210,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

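Seeding the accumulators as MLX scalars rather than Python ints keeps them valid `mx.array` inputs for later array-only operations even if the loop body never runs; that reading is an inference from the diff, not stated in it. A tiny sketch of one plausible accumulation pattern with dummy per-batch values (the loop body itself is not part of this diff):

```python
import mlx.core as mx

# Dummy per-batch (loss, token-count) pairs standing in for an evaluation loop.
per_batch = [(mx.array(2.0), mx.array(8)), (mx.array(1.5), mx.array(6))]

all_losses = mx.array(0.0)
ntokens = mx.array(0)
for losses, toks in per_batch:
    all_losses += losses * toks
    ntokens += toks

print((all_losses / ntokens).item())  # token-weighted mean loss
```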
@@ -283,6 +292,7 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -295,10 +305,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -309,7 +320,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -326,24 +337,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -372,7 +382,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
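Taken together, the timing changes replace the single `start`/`stop` pair with a per-step `tic` and an accumulated `train_time`, so the reported iterations/sec and tokens/sec count only time spent inside training steps, not validation or reporting. A toy, self-contained illustration of that bookkeeping (the sleeps are placeholders, not real work):

```python
import time

train_time, steps, tokens = 0.0, 0, 0
for _ in range(3):
    tic = time.perf_counter()
    time.sleep(0.01)                 # placeholder for an optimizer step
    train_time += time.perf_counter() - tic
    steps += 1
    tokens += 128
    time.sleep(0.02)                 # placeholder for validation/saving: not counted

print(f"{steps / train_time:.1f} it/s, {tokens / train_time:.0f} tok/s")
```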