mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Some improvements to LoRA (#528)
* set cache_limit * remove set cache_limit * cleanup * add gradient checkpointing * fix sort * mokey patch call for checkpoint * fix example config
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# The path to the local model directory or Hugging Face repo.
|
||||
model: "mlx_model"
|
||||
|
||||
# Whether or not to train (boolean)
|
||||
train: true
|
||||
|
||||
@@ -49,6 +48,9 @@ test_batches: 500
|
||||
# Maximum sequence length.
|
||||
max_seq_length: 2048
|
||||
|
||||
# Use gradient checkpointing to reduce memory use.
|
||||
grad_checkpoint: false
|
||||
|
||||
# LoRA parameters can only be specified in a config file
|
||||
lora_parameters:
|
||||
# The layer keys to apply LoRA to.
|
||||
|
@@ -145,7 +145,12 @@ def build_parser():
|
||||
default=None,
|
||||
help="A YAML configuration file with the training options",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||
parser.add_argument(
|
||||
"--grad-checkpoint",
|
||||
action="store_true",
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -222,6 +227,7 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=args.adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
if args.train:
|
||||
print("Training")
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -10,6 +11,22 @@ import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
"""
|
||||
Update all instances of type(layer) to use gradient checkpointing.
|
||||
"""
|
||||
fn = type(layer).__call__
|
||||
|
||||
def checkpointed_fn(model, *args, **kwargs):
|
||||
def inner_fn(params, *args, **kwargs):
|
||||
model.update(params)
|
||||
return fn(model, *args, **kwargs)
|
||||
|
||||
return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
|
||||
|
||||
type(layer).__call__ = checkpointed_fn
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
lora_layers: int = field(
|
||||
@@ -40,6 +57,10 @@ class TrainingArgs:
|
||||
default="adapter.npz",
|
||||
metadata={"help": "Save/load path for the trained adapter weights."},
|
||||
)
|
||||
grad_checkpoint: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
@@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths):
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
|
||||
# Make the batches:
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
|
||||
while True:
|
||||
# Shuffle indices
|
||||
indices = np.arange(len(dataset))
|
||||
indices = np.random.permutation(indices)
|
||||
# Collect batches from dataset
|
||||
for i in range(0, len(indices) - batch_size + 1, batch_size):
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
# Encode batch
|
||||
batch = [
|
||||
tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
|
||||
]
|
||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
@@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
|
||||
# Pad to the max length
|
||||
max_length_in_batch = min(max(lengths), max_seq_length)
|
||||
# Pad to the nearest multiple of 8 or the maximum length
|
||||
pad_to = 8
|
||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
|
||||
|
||||
for j in range(batch_size):
|
||||
@@ -157,7 +184,21 @@ def train(
|
||||
# Create checkpoints directory if it does not exist
|
||||
adapter_path = checkpoints_path(args.adapter_file)
|
||||
|
||||
# Create value and grad function for loss
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(batch):
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
return lvalue, toks
|
||||
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = []
|
||||
@@ -175,13 +216,8 @@ def train(
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
mx.eval(model.parameters(), optimizer.state, lvalue)
|
||||
lvalue, toks = step(batch)
|
||||
mx.eval(state, lvalue, toks)
|
||||
|
||||
# Record loss
|
||||
losses.append(lvalue.item())
|
||||
@@ -196,12 +232,14 @@ def train(
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"Learning Rate {learning_rate:.3e}, "
|
||||
f"It/sec {it_sec:.3f}, "
|
||||
f"Tokens/sec {tokens_sec:.3f}, "
|
||||
f"Trained Tokens {trained_tokens}"
|
||||
f"Trained Tokens {trained_tokens}, "
|
||||
f"Peak mem {peak_mem:.3f} GB"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
|
@@ -35,9 +35,6 @@ def linear_to_lora_layers(
|
||||
lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
|
||||
)
|
||||
|
||||
# If the lora_parameters are set, we assume the keys
|
||||
# are correct for the given model
|
||||
|
||||
keys = config.get("keys", None)
|
||||
if keys is not None:
|
||||
keys = set(keys)
|
||||
@@ -53,7 +50,7 @@ def linear_to_lora_layers(
|
||||
]:
|
||||
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
||||
if model.model_type == "mixtral":
|
||||
keys.add(["block_sparse_moe.gate"])
|
||||
keys.add("block_sparse_moe.gate")
|
||||
elif model.model_type == "olmo":
|
||||
keys = set(["att_proj"])
|
||||
elif model.model_type == "phi-msft":
|
||||
|
Reference in New Issue
Block a user