Merge remote-tracking branch 'upstream/main' into mitmul/add-plamo2-1b-support

Shunta Saito 2025-02-24 13:37:43 +09:00
commit 675c322978
6 changed files with 35 additions and 50 deletions

View File

@@ -121,7 +121,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)
     print("[INFO] Loading")
-    torch_weights = torch.load(torch_path / "pytorch_model.bin")
+    torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
     print("[INFO] Converting")
     mlx_weights = {
         k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
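
Note on the change above: passing weights_only=True to torch.load restricts unpickling to tensors and plain containers, so a crafted checkpoint file cannot execute arbitrary code while being loaded. A minimal, hypothetical sketch (the "model.bin" path is a placeholder, not a file from this repo):

    import torch

    # weights_only=True limits the unpickler to tensor data and basic containers,
    # which is the safer choice for checkpoints downloaded from untrusted sources.
    state_dict = torch.load("model.bin", map_location="cpu", weights_only=True)
    print({k: tuple(v.shape) for k, v in list(state_dict.items())[:3]})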

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.0"
+__version__ = "0.21.5"

View File

@@ -181,8 +181,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
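
Why the min-to-max change above matters: with min, any positive num_layers collapses to 0, and a [-0:] slice is the whole list, so every layer would be trained instead of the requested count. A short plain-Python sketch of the slice semantics (toy layer names, hypothetical values):

    layers = ["layer0", "layer1", "layer2", "layer3"]
    num_layers = 2

    # Old expression: min(2, 0) == 0, and layers[-0:] == layers[0:], i.e. all layers.
    print(layers[-min(num_layers, 0):])   # ['layer0', 'layer1', 'layer2', 'layer3']

    # Fixed expression: max(2, 0) == 2, so only the last two layers are selected.
    print(layers[-max(num_layers, 0):])   # ['layer2', 'layer3']

    # With num_layers == -1 (commonly used to mean "all layers"), max(-1, 0) == 0
    # and the slice again covers every layer.
    print(layers[-max(-1, 0):])           # all four layers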

View File

@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
-        )
 
     def to_lora(layer):
         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
@@ -154,7 +149,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[-min(num_layers, 0) :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))

View File

@@ -410,8 +410,7 @@ def speculative_generate_step(
             for processor in logits_processors:
                 logits = processor(tokens, logits)
-        logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        logprobs = logprobs.squeeze(0)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
         y = sampler(logprobs)
         return y, logprobs
@@ -430,16 +429,14 @@ def speculative_generate_step(
                 prev_tokens = (
                     mx.concat([prev_tokens, y]) if prev_tokens is not None else y
                 )
-                y, logprobs = _process_and_sample(
-                    prev_tokens, logits[:, i : i + 1, :]
-                )
+                y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
                 out_y.append(y)
                 out_logprobs.append(logprobs)
             return mx.concatenate(out_y, axis=0), mx.concatenate(
                 out_logprobs, axis=0
             )
         else:
-            return _process_and_sample(None, logits)
+            return _process_and_sample(None, logits.squeeze(0))
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -477,13 +474,9 @@ def speculative_generate_step(
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
             if prev_tokens is not None:
-                prev_tokens = prev_tokens[
-                    : prev_tokens.size - draft_y.size - num_draft + 1
-                ]
+                prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
             y = mx.concatenate([y, draft_tokens])
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
@@ -515,8 +508,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
-            if prev_tokens is not None and n < num_draft:
-                prev_tokens = prev_tokens[: -(num_draft - n)]
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
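
A side note on the logprobs change in this file: mx.logsumexp with no axis reduces over every element, so subtracting it only yields a proper distribution when logits is already one-dimensional. Passing axis=-1 normalizes each row of a batched logits array independently. A small sketch of the difference (hypothetical values, not code from this repo):

    import mlx.core as mx

    logits = mx.array([[2.0, 1.0, 0.0],
                       [0.0, 0.0, 0.0]])

    # axis=-1 normalizes each row separately: both rows sum to 1 in probability space.
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    print(mx.exp(logprobs).sum(axis=-1))   # close to [1.0, 1.0]

    # Without an axis, logsumexp reduces over the whole array, so the individual
    # rows are no longer normalized distributions.
    bad = logits - mx.logsumexp(logits, keepdims=True)
    print(mx.exp(bad).sum(axis=-1))        # does not give [1.0, 1.0]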

View File

@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten
@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))
 
 
 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
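
For context on the to_samples rewrite above: instead of building overlapping windows with numpy stride tricks, the data is truncated to a multiple of window_size and reshaped into non-overlapping rows of context_size + 1 tokens; inputs and targets are later recovered by dropping the last or first column of each row. A small sketch of that windowing (toy token ids, not repo code):

    import mlx.core as mx

    context_size = 3
    window_size = context_size + 1           # include the target token
    dataset = mx.arange(10)                  # toy token ids 0..9

    samples = dataset.size // window_size    # 2 full windows; 2 leftover tokens dropped
    windows = dataset[: samples * window_size].reshape(samples, -1)
    # windows rows: 0 1 2 3  and  4 5 6 7

    # Each row carries both inputs and shifted targets; loss_fn splits them:
    x, y = windows[..., :-1], windows[..., 1:]
    # x rows: 0 1 2 / 4 5 6    y rows: 1 2 3 / 5 6 7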
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)
 
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
 
     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])
 
     state = [model.state, optimizer.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss
 
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "