From 1cbf5cdac7b081b09bb4f8a8cb4909ff9fdcf108 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 19 Feb 2025 06:22:51 -0800
Subject: [PATCH 1/4] use more standard window strategy (#1287)

---
 transformer_lm/main.py | 47 +++++++++++++++++++----------------------------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/transformer_lm/main.py b/transformer_lm/main.py
index dc725cbe..7ff5b73f 100644
--- a/transformer_lm/main.py
+++ b/transformer_lm/main.py
@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten
 
 
@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
 
 
 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))
 
 
 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)
 
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
 
     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])
 
     state = [model.state, optimizer.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss
 
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "

From 85669451d0e4cfb2370994bdcad38b190cfbb417 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 20 Feb 2025 13:32:01 -0800
Subject: [PATCH 2/4] Fix num layers in fine tune (#1294)

---
 llms/mlx_lm/lora.py        | 8 +++++++-
 llms/mlx_lm/tuner/utils.py | 7 +------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index abc5dfa9..def3b6dd 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -181,8 +181,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 9f8d2925..f5df11e3 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
-        )
 
     def to_lora(layer):
         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
@@ -154,7 +149,7 @@ def linear_to_lora_layers(
     else:
        raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[-min(num_layers, 0) :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
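A note on the window strategy in PATCH 1/4 above: the old `to_samples` built one heavily overlapping sample per token offset with `np.lib.stride_tricks.as_strided`, while the new version chops the token stream into disjoint windows of `context_size + 1` tokens and lets `loss_fn` split each window into shifted inputs and targets. A toy sketch of the new windowing, on made-up token ids rather than code from the repo:

```python
# Sketch of the disjoint-window strategy from PATCH 1/4, on toy token ids.
import mlx.core as mx

context_size = 3
window_size = context_size + 1  # include target
tokens = mx.arange(10)  # stand-in for a tokenized corpus

samples = tokens.size // window_size  # 2 full windows; the remainder is dropped
windows = tokens[: samples * window_size].reshape(samples, -1)
print(windows)  # two windows: [0 1 2 3] and [4 5 6 7]; tokens 8 and 9 discarded

# loss_fn later recovers inputs and targets by shifting each window by one:
x, y = windows[..., :-1], windows[..., 1:]
```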
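The one-character `min` to `max` change in PATCH 2/4 is easy to misread, so here is a standalone sketch (a made-up layer list, not repo code) of why it matters: for any positive `num_layers`, `-min(num_layers, 0)` is always `0`, so the old slice selected every layer regardless of the request; `-max(num_layers, 0)` selects only the last `num_layers`, and a `num_layers = -1` "all layers" convention (which the surrounding code appears to allow) still works because `-max(-1, 0)` is also `0`.

```python
# Sketch of the slicing bug fixed in PATCH 2/4 (illustrative list, not repo code).
layers = ["layer0", "layer1", "layer2", "layer3"]

num_layers = 2
assert layers[-min(num_layers, 0) :] == layers                # old: all 4 layers (bug)
assert layers[-max(num_layers, 0) :] == ["layer2", "layer3"]  # new: last 2 (intended)

num_layers = -1  # presumably "tune all layers"
assert layers[-max(num_layers, 0) :] == layers                # still the whole list
```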
From 3d793ecf6887512fd81f0f0c6bd156c06a6eaf88 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 20 Feb 2025 15:55:55 -0800
Subject: [PATCH 3/4] Fix logits processor bugs with spec dec (#1291)

* Fix logits processor bugs with spec dec

* bump patch
---
 llms/mlx_lm/_version.py |  2 +-
 llms/mlx_lm/utils.py    | 19 ++++++-------------
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index b2f98e6f..89e6cd00 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.0"
+__version__ = "0.21.5"
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 78a2e802..1fae76fa 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -409,8 +409,7 @@ def speculative_generate_step(
             for processor in logits_processors:
                 logits = processor(tokens, logits)
 
-        logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        logprobs = logprobs.squeeze(0)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
         y = sampler(logprobs)
         return y, logprobs
 
@@ -429,16 +428,14 @@ def speculative_generate_step(
                 prev_tokens = (
                     mx.concat([prev_tokens, y]) if prev_tokens is not None else y
                 )
-                y, logprobs = _process_and_sample(
-                    prev_tokens, logits[:, i : i + 1, :]
-                )
+                y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
                 out_y.append(y)
                 out_logprobs.append(logprobs)
             return mx.concatenate(out_y, axis=0), mx.concatenate(
                 out_logprobs, axis=0
             )
         else:
-            return _process_and_sample(None, logits)
+            return _process_and_sample(None, logits.squeeze(0))
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
@@ -476,13 +473,9 @@ def speculative_generate_step(
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
             if prev_tokens is not None:
-                prev_tokens = prev_tokens[
-                    : prev_tokens.size - draft_y.size - num_draft + 1
-                ]
+                prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
             y = mx.concatenate([y, draft_tokens])
-
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
-
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
@@ -514,8 +507,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
-            if prev_tokens is not None and n < num_draft:
-                prev_tokens = prev_tokens[: -(num_draft - n)]
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)

From 09b641aaa74f9737f747b62ad8c628405e7e25be Mon Sep 17 00:00:00 2001
From: Usama Ahmed <53372259+0ssamaak0@users.noreply.github.com>
Date: Sat, 22 Feb 2025 17:08:54 +0300
Subject: [PATCH 4/4] Fix FutureWarning in torch.load by setting
 weights_only=True (#1295)

---
 clip/convert.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clip/convert.py b/clip/convert.py
index 29bac22e..976d7494 100644
--- a/clip/convert.py
+++ b/clip/convert.py
@@ -121,7 +121,7 @@ if __name__ == "__main__":
     mlx_path.mkdir(parents=True, exist_ok=True)
 
     print("[INFO] Loading")
-    torch_weights = torch.load(torch_path / "pytorch_model.bin")
+    torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
     print("[INFO] Converting")
     mlx_weights = {
         k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
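On the `logsumexp` change in PATCH 3/4: subtracting `mx.logsumexp(logits, axis=-1, keepdims=True)` turns each row of logits into normalized log-probabilities. Without `axis=-1` the reduction runs over the entire array, which only happens to be correct when there is a single row; with several draft positions the rows stop summing to one. A minimal sketch, assuming `mlx` is installed and with illustrative shapes:

```python
# Sketch of the normalization fixed in PATCH 3/4 (illustrative shapes).
import mlx.core as mx

logits = mx.random.normal((5, 32))  # e.g. 5 draft positions, vocab size 32

# Fixed: normalize each position over its own vocabulary axis.
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
print(mx.exp(logprobs).sum(axis=-1))  # every row sums to ~1.0

# Old behavior: no axis, so the reduction spans all 5 * 32 entries and the
# rows sum to roughly 1/5 each instead of 1.
bad = logits - mx.logsumexp(logits, keepdims=True)
print(mx.exp(bad).sum(axis=-1))
```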
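And on PATCH 4/4: `weights_only=True` tells `torch.load` to unpickle only tensors and primitive containers rather than arbitrary Python objects, which silences the `FutureWarning` on recent PyTorch releases (where the default eventually flips to `True`) and avoids executing untrusted pickled code. A minimal usage sketch; the checkpoint path below is a placeholder, not a file from the repo:

```python
# Sketch of the safer loading pattern from PATCH 4/4; the path is a placeholder.
import torch

state_dict = torch.load("pytorch_model.bin", weights_only=True)
for name, tensor in state_dict.items():
    print(f"{name}: {tuple(tensor.shape)}")
```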