mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge remote-tracking branch 'upstream/main' into mitmul/add-plamo2-1b-support
This commit is contained in:
commit
675c322978
@ -121,7 +121,7 @@ if __name__ == "__main__":
|
|||||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print("[INFO] Loading")
|
print("[INFO] Loading")
|
||||||
torch_weights = torch.load(torch_path / "pytorch_model.bin")
|
torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
|
||||||
print("[INFO] Converting")
|
print("[INFO] Converting")
|
||||||
mlx_weights = {
|
mlx_weights = {
|
||||||
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.21.0"
|
__version__ = "0.21.5"
|
||||||
|
@ -181,8 +181,14 @@ def train_model(
|
|||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
model.freeze()
|
model.freeze()
|
||||||
|
if args.num_layers > len(model.layers):
|
||||||
|
raise ValueError(
|
||||||
|
f"Requested to train {args.num_layers} layers "
|
||||||
|
f"but the model only has {len(model.layers)} layers."
|
||||||
|
)
|
||||||
|
|
||||||
if args.fine_tune_type == "full":
|
if args.fine_tune_type == "full":
|
||||||
for l in model.layers[-min(args.num_layers, 0) :]:
|
for l in model.layers[-max(args.num_layers, 0) :]:
|
||||||
l.unfreeze()
|
l.unfreeze()
|
||||||
elif args.fine_tune_type in ["lora", "dora"]:
|
elif args.fine_tune_type in ["lora", "dora"]:
|
||||||
# Convert linear layers to lora/dora layers and unfreeze in the process
|
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||||
|
@ -52,11 +52,6 @@ def linear_to_lora_layers(
|
|||||||
use_dora (bool): If True, uses DoRA instead of LoRA.
|
use_dora (bool): If True, uses DoRA instead of LoRA.
|
||||||
Default: ``False``
|
Default: ``False``
|
||||||
"""
|
"""
|
||||||
if num_layers > len(model.layers):
|
|
||||||
raise ValueError(
|
|
||||||
f"Requested {num_layers} LoRA layers "
|
|
||||||
f"but the model only has {len(model.layers)} layers."
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_lora(layer):
|
def to_lora(layer):
|
||||||
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
|
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
|
||||||
@ -154,7 +149,7 @@ def linear_to_lora_layers(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Lora does not support {model.model_type}")
|
raise ValueError(f"Lora does not support {model.model_type}")
|
||||||
|
|
||||||
for l in model.layers[-min(num_layers, 0) :]:
|
for l in model.layers[-max(num_layers, 0) :]:
|
||||||
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
||||||
if lora_layers:
|
if lora_layers:
|
||||||
l.update_modules(tree_unflatten(lora_layers))
|
l.update_modules(tree_unflatten(lora_layers))
|
||||||
|
@ -410,8 +410,7 @@ def speculative_generate_step(
|
|||||||
for processor in logits_processors:
|
for processor in logits_processors:
|
||||||
logits = processor(tokens, logits)
|
logits = processor(tokens, logits)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
logprobs = logprobs.squeeze(0)
|
|
||||||
y = sampler(logprobs)
|
y = sampler(logprobs)
|
||||||
return y, logprobs
|
return y, logprobs
|
||||||
|
|
||||||
@ -430,16 +429,14 @@ def speculative_generate_step(
|
|||||||
prev_tokens = (
|
prev_tokens = (
|
||||||
mx.concat([prev_tokens, y]) if prev_tokens is not None else y
|
mx.concat([prev_tokens, y]) if prev_tokens is not None else y
|
||||||
)
|
)
|
||||||
y, logprobs = _process_and_sample(
|
y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
|
||||||
prev_tokens, logits[:, i : i + 1, :]
|
|
||||||
)
|
|
||||||
out_y.append(y)
|
out_y.append(y)
|
||||||
out_logprobs.append(logprobs)
|
out_logprobs.append(logprobs)
|
||||||
return mx.concatenate(out_y, axis=0), mx.concatenate(
|
return mx.concatenate(out_y, axis=0), mx.concatenate(
|
||||||
out_logprobs, axis=0
|
out_logprobs, axis=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return _process_and_sample(None, logits)
|
return _process_and_sample(None, logits.squeeze(0))
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
@ -477,13 +474,9 @@ def speculative_generate_step(
|
|||||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||||
draft_tokens = _draft_generate(draft_y, num_draft)
|
draft_tokens = _draft_generate(draft_y, num_draft)
|
||||||
if prev_tokens is not None:
|
if prev_tokens is not None:
|
||||||
prev_tokens = prev_tokens[
|
prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
|
||||||
: prev_tokens.size - draft_y.size - num_draft + 1
|
|
||||||
]
|
|
||||||
y = mx.concatenate([y, draft_tokens])
|
y = mx.concatenate([y, draft_tokens])
|
||||||
|
|
||||||
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
||||||
|
|
||||||
mx.eval(tokens, draft_tokens)
|
mx.eval(tokens, draft_tokens)
|
||||||
draft_tokens = draft_tokens.tolist()
|
draft_tokens = draft_tokens.tolist()
|
||||||
tokens = tokens.tolist()
|
tokens = tokens.tolist()
|
||||||
@ -515,8 +508,8 @@ def speculative_generate_step(
|
|||||||
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
||||||
)
|
)
|
||||||
|
|
||||||
if prev_tokens is not None and n < num_draft:
|
if prev_tokens is not None:
|
||||||
prev_tokens = prev_tokens[: -(num_draft - n)]
|
prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
|
||||||
_rewind_cache(num_draft, n)
|
_rewind_cache(num_draft, n)
|
||||||
finally:
|
finally:
|
||||||
_rewind_cache(num_draft, n)
|
_rewind_cache(num_draft, n)
|
||||||
|
@ -8,7 +8,6 @@ import datasets
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
import numpy as np
|
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def to_samples(context_size, dataset):
|
def to_samples(context_size, dataset):
|
||||||
tokens = dataset.size
|
|
||||||
window_size = context_size + 1 # include target
|
window_size = context_size + 1 # include target
|
||||||
samples = tokens - window_size + 1
|
samples = dataset.size // window_size
|
||||||
X = np.lib.stride_tricks.as_strided(
|
dataset = dataset[: samples * window_size]
|
||||||
dataset,
|
return mx.array(dataset.reshape(samples, -1))
|
||||||
shape=(samples, window_size),
|
|
||||||
strides=(dataset.itemsize, dataset.itemsize),
|
|
||||||
)
|
|
||||||
return X[:, :-1], X[:, 1:]
|
|
||||||
|
|
||||||
|
|
||||||
def iterate_batches(batch_size, context_size, dataset):
|
def iterate_batches(batch_size, context_size, dataset):
|
||||||
inputs, targets = to_samples(context_size, dataset)
|
inputs = to_samples(context_size, dataset)
|
||||||
s = 0
|
s = 0
|
||||||
while True:
|
while True:
|
||||||
if s == 0:
|
if s == 0:
|
||||||
# Reset permutation:
|
# Reset permutation:
|
||||||
perm = np.random.permutation(inputs.shape[0])
|
perm = mx.random.permutation(inputs.shape[0])
|
||||||
ids = perm[s : s + batch_size]
|
ids = perm[s : s + batch_size]
|
||||||
yield inputs[ids], targets[ids]
|
yield inputs[ids]
|
||||||
s += batch_size
|
s += batch_size
|
||||||
if s >= inputs.shape[0]:
|
if s >= inputs.shape[0]:
|
||||||
s = 0
|
s = 0
|
||||||
@ -84,45 +78,42 @@ def main(args):
|
|||||||
)
|
)
|
||||||
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
||||||
|
|
||||||
def loss_fn(model, x, y, reduce=True):
|
def loss_fn(model, inputs, reduction="mean"):
|
||||||
|
x, y = inputs[..., :-1], inputs[..., 1:]
|
||||||
logits = model(x)
|
logits = model(x)
|
||||||
losses = nn.losses.cross_entropy(logits, y)
|
return nn.losses.cross_entropy(logits, y, reduction=reduction)
|
||||||
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
|
||||||
|
|
||||||
optimizer = optim.AdamW(
|
optimizer = optim.AdamW(
|
||||||
learning_rate=args.learning_rate, weight_decay=args.weight_decay
|
learning_rate=args.learning_rate, weight_decay=args.weight_decay
|
||||||
)
|
)
|
||||||
|
|
||||||
def eval_fn(dataset):
|
def eval_fn(dataset):
|
||||||
inputs, targets = map(mx.array, to_samples(context_size, dataset))
|
inputs = to_samples(context_size, dataset)
|
||||||
loss = 0
|
loss = 0
|
||||||
for s in range(0, targets.shape[0], batch_size):
|
for s in range(0, inputs.shape[0], batch_size):
|
||||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
|
||||||
bx, by = map(mx.array, (bx, by))
|
loss += losses.item()
|
||||||
losses = loss_fn(model, bx, by, reduce=False)
|
return loss / (inputs.size - inputs.shape[0])
|
||||||
loss += mx.sum(losses).item()
|
|
||||||
return loss / len(targets)
|
|
||||||
|
|
||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def step(inputs, targets):
|
def step(inputs):
|
||||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
loss, grads = loss_and_grad_fn(model, inputs, targets)
|
loss, grads = loss_and_grad_fn(model, inputs)
|
||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
train_iterator = iterate_batches(batch_size, context_size, train)
|
train_iterator = iterate_batches(batch_size, context_size, train)
|
||||||
losses = []
|
losses = []
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
|
for it, inputs in zip(range(args.num_iters), train_iterator):
|
||||||
inputs, targets = map(mx.array, (inputs, targets))
|
|
||||||
optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
|
optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
|
||||||
loss = step(inputs, targets)
|
loss = step(inputs)
|
||||||
mx.eval(state)
|
mx.eval(state)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
if (it + 1) % steps_per_report == 0:
|
if (it + 1) % steps_per_report == 0:
|
||||||
train_loss = np.mean(losses)
|
train_loss = sum(losses) / len(losses)
|
||||||
toc = time.perf_counter()
|
toc = time.perf_counter()
|
||||||
print(
|
print(
|
||||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||||
|
Loading…
Reference in New Issue
Block a user