update upstream

Author: paNikitin
Date: 2025-02-23 12:48:25 +03:00
parent 0f790c4c84
commit a2b61afd05
2 changed files with 40 additions and 16 deletions


@@ -6,6 +6,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import List, Optional, Tuple
+from typing import Union
 import mlx.core as mx
@@ -13,7 +14,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
+from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -70,14 +71,18 @@ class TrainingArgs:
     )
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
     logits = model(inputs)
     logits = logits.astype(mx.float32)
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
     return ce, ntoks
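The reworked loss expects its second argument to carry an (offset, length) pair per sequence, so tokens belonging to the prompt (before the offset) and to the padding (past the length) are both dropped from the cross-entropy. A minimal sketch of just the masking arithmetic, with made-up offsets and lengths:

```python
import mlx.core as mx

# Hypothetical values: two sequences whose prompts are 2 and 0 tokens long,
# with target lengths of 4 and 3 tokens; 5 target positions after the shift.
lengths = mx.array([[2, 4], [0, 3]])
num_targets = 5

steps = mx.arange(1, num_targets + 1)            # [1, 2, 3, 4, 5]
mask = mx.logical_and(steps >= lengths[:, 0:1],  # at or past the prompt offset
                      steps <= lengths[:, 1:])   # and not past the sequence end
print(mask)  # row 0 -> [F, T, T, T, F], row 1 -> [T, T, T, F, F]
```

Only the True positions contribute to ce and to the token count ntoks.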
@@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                )
             batch = mx.array(batch_arr)
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
         if not train:
             break
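The iterator now yields the unshifted token batch (default_loss does the slicing itself) together with an array of (offset, length) pairs. The new branch covers both dataset flavors; a rough sketch of that unpacking, using hypothetical entries:

```python
# Hypothetical entries: completion-style datasets store (tokens, offset) tuples,
# where offset marks how many leading tokens belong to the prompt.
batch = [([101, 7, 8, 9], 2), ([101, 5, 6], 0)]

if len(batch[0]) == 2:
    batch, offsets = zip(*batch)   # -> token lists and their prompt offsets
else:
    offsets = [0] * len(batch)     # plain text data: nothing to mask

lengths = [len(x) for x in batch]
print(list(zip(offsets, lengths)))  # [(2, 4), (0, 3)], the rows fed to default_loss
```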
@@ -201,8 +210,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
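Starting the accumulators as mx.array values rather than Python scalars keeps the running sums on the MLX side, where they can be passed to mx.eval each batch and, in the distributed case, to mx.distributed.all_sum before the final .item(). A small sketch of a token-weighted accumulation in this style, with placeholder per-batch values standing in for the loss function's output:

```python
import mlx.core as mx

all_losses = mx.array(0.0)
ntokens = mx.array(0)

# Placeholder (loss, token-count) pairs; in evaluate these come from the loss fn.
for losses, toks in [(mx.array(2.0), mx.array(8)), (mx.array(1.5), mx.array(6))]:
    all_losses += losses * toks
    ntokens += toks
    mx.eval(all_losses, ntokens)   # materialize the running sums each step

# In a multi-rank run the sums would go through mx.distributed.all_sum first.
print((all_losses / ntokens).item())   # token-weighted mean loss ~= 1.786
```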
@@ -283,6 +292,7 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -295,10 +305,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -309,7 +320,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -326,24 +337,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)
-            start = time.perf_counter()
+            tic = time.perf_counter()
         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
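Accumulating train_time around the step (and resetting tic after the validation block) means the reported it/sec and tokens/sec now measure training compute only, instead of wall-clock time that also included evaluation and report printing. A hedged sketch of the timing pattern with a stand-in step function:

```python
import time

def step(batch):                      # stand-in for the real forward/backward step
    time.sleep(0.01)
    return 0.0, len(batch)            # (loss value, tokens processed)

train_time = 0.0
n_tokens = 0
steps_per_report = 3

for it, batch in enumerate([[1] * 16] * 6, start=1):
    tic = time.perf_counter()         # restarted after any validation pass
    _, toks = step(batch)
    n_tokens += toks
    train_time += time.perf_counter() - tic   # only time spent in the step

    if it % steps_per_report == 0:
        print(f"{steps_per_report / train_time:.1f} it/s, "
              f"{n_tokens / train_time:.0f} tok/s")
        n_tokens, train_time = 0, 0.0         # reset, as the train loop does
```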
@@ -372,7 +382,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0
         # Save adapter weights
         if it % args.steps_per_save == 0: