Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
Some improvements to LoRA (#528)

* set cache_limit
* remove set cache_limit
* cleanup
* add gradient checkpointing
* fix sort
* monkey patch call for checkpoint
* fix example config

This commit is contained in: parent e56d9015ef, commit 39084e81c2

lora_config.yaml
@@ -1,6 +1,5 @@
 # The path to the local model directory or Hugging Face repo.
 model: "mlx_model"
-
 # Whether or not to train (boolean)
 train: true
 
@@ -49,6 +48,9 @@ test_batches: 500
 # Maximum sequence length.
 max_seq_length: 2048
 
+# Use gradient checkpointing to reduce memory use.
+grad_checkpoint: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.

lora.py
@@ -145,7 +145,12 @@ def build_parser():
         default=None,
         help="A YAML configuration file with the training options",
     )
-    parser.add_argument("--seed", type=int, help="The PRNG seed")
+    parser.add_argument(
+        "--grad-checkpoint",
+        action="store_true",
+        help="Use gradient checkpointing to reduce memory use.",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
     return parser
 
 
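
As a hedged sketch (not the repository's actual merging code), the following shows how a YAML config like the one above and the new --grad-checkpoint flag can end up in the same argparse namespace, with explicit command-line flags taking precedence over file values; the inline CONFIG_TEXT and the merge loop are illustrative assumptions.

# Hedged sketch, not the repo's exact logic: merge a YAML config with CLI flags.
import argparse

import yaml

CONFIG_TEXT = """
model: "mlx_model"
train: true
max_seq_length: 2048
grad_checkpoint: false
"""

parser = argparse.ArgumentParser()
parser.add_argument("--grad-checkpoint", action="store_true")
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args(["--grad-checkpoint"])  # user explicitly enables it

config = yaml.safe_load(CONFIG_TEXT)
for key, value in config.items():
    # Only fill in values the user did not set on the command line.
    if getattr(args, key, None) in (None, False):
        setattr(args, key, value)

print(args.grad_checkpoint)  # True: the CLI flag wins over grad_checkpoint: false

Under a scheme like this, grad_checkpoint: false in the file is only a default; passing --grad-checkpoint still turns checkpointing on.
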
lora.py (continued)
@@ -222,6 +227,7 @@ def run(args, training_callback: TrainingCallback = None):
         steps_per_save=args.save_every,
         adapter_file=args.adapter_file,
         max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
     )
     if args.train:
         print("Training")

tuner/trainer.py
@@ -2,6 +2,7 @@
 
 import time
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
 
 import mlx.core as mx
@@ -10,6 +11,22 @@ import numpy as np
 from mlx.utils import tree_flatten
 
 
+def grad_checkpoint(layer):
+    """
+    Update all instances of type(layer) to use gradient checkpointing.
+    """
+    fn = type(layer).__call__
+
+    def checkpointed_fn(model, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            model.update(params)
+            return fn(model, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
+
+    type(layer).__call__ = checkpointed_fn
+
+
 @dataclass
 class TrainingArgs:
     lora_layers: int = field(
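
To make the monkey patch concrete, here is a hedged, self-contained toy example; ToyBlock and the loss below are invented for illustration, while grad_checkpoint is copied from the hunk above. Replacing type(layer).__call__ routes every forward call of that layer class through mx.checkpoint, so intermediate activations are recomputed during the backward pass instead of being kept alive, trading compute for memory.

# Illustrative toy usage; ToyBlock and the loss are made-up stand-ins.
import mlx.core as mx
import mlx.nn as nn


def grad_checkpoint(layer):
    """Update all instances of type(layer) to use gradient checkpointing."""
    fn = type(layer).__call__

    def checkpointed_fn(model, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            model.update(params)
            return fn(model, *args, **kwargs)

        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn


class ToyBlock(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.fc1 = nn.Linear(dims, dims)
        self.fc2 = nn.Linear(dims, dims)

    def __call__(self, x):
        return self.fc2(nn.relu(self.fc1(x)))


block = ToyBlock(16)
grad_checkpoint(block)  # patches ToyBlock.__call__ for every instance


def loss(model, x):
    return model(x).sum()


value, grads = nn.value_and_grad(block, loss)(block, mx.ones((4, 16)))
mx.eval(value, grads)

Because the patch is applied to the class rather than the instance, calling grad_checkpoint(model.layers[0]) in the training code below is enough to checkpoint every transformer block of that type.
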
tuner/trainer.py (continued)
@@ -40,6 +57,10 @@ class TrainingArgs:
         default="adapter.npz",
         metadata={"help": "Save/load path for the trained adapter weights."},
     )
+    grad_checkpoint: bool = field(
+        default=False,
+        metadata={"help": "Use gradient checkpointing to reduce memory use."},
+    )
 
 
 def default_loss(model, inputs, targets, lengths):
@@ -56,16 +77,19 @@ def default_loss(model, inputs, targets, lengths):
 
 
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    # Sort by length:
+    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
+
     while True:
-        # Shuffle indices
-        indices = np.arange(len(dataset))
-        indices = np.random.permutation(indices)
-        # Collect batches from dataset
-        for i in range(0, len(indices) - batch_size + 1, batch_size):
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
             # Encode batch
-            batch = [
-                tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
-            ]
+            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             lengths = [len(x) for x in batch]
 
             if max(lengths) > max_seq_length:
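
A hedged toy illustration of the new batching scheme (the dataset and the whitespace "tokenizer" below are made up): examples are sorted by length once, grouped into fixed batches of similar-length neighbours, and only the order of those batches is shuffled each epoch, which cuts down on padding compared with shuffling individual examples.

# Illustrative only: toy data and a whitespace "tokenizer" stand in for the real ones.
import numpy as np

dataset = ["a b", "a b c d e", "a", "a b c", "a b c d", "a b c d e f g"]
batch_size = 2


def tokenize(s):
    return s.split()


# Sort example indices by length, then group neighbours into batches.
idx = sorted(range(len(dataset)), key=lambda i: len(tokenize(dataset[i])))
batch_idx = [
    idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
]

# Each epoch, shuffle which batch comes first, not which examples sit together.
for i in np.random.permutation(len(batch_idx)):
    batch = [tokenize(dataset[j]) for j in batch_idx[i]]
    print([len(x) for x in batch])  # lengths within a batch stay close together
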
tuner/trainer.py (continued)
@@ -75,8 +99,11 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                 "Consider pre-splitting your data to save memory."
             )
 
-            # Pad to the max length
-            max_length_in_batch = min(max(lengths), max_seq_length)
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
             batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
 
             for j in range(batch_size):
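
A quick worked example of the new padding rule, with made-up lengths: each batch is now padded up to the next multiple of 8 (capped at max_seq_length) rather than to its exact maximum, which keeps the set of batch shapes small at the cost of a few extra pad tokens.

# Worked example of the pad-to-multiple-of-8 rule (the lengths are made up).
lengths = [37, 41, 44]  # token counts in one batch
max_seq_length = 2048
pad_to = 8

max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
print(max_length_in_batch)  # 48; the old rule would have padded to 44
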
tuner/trainer.py (continued)
@@ -157,7 +184,21 @@ def train(
     # Create checkpoints directory if it does not exist
     adapter_path = checkpoints_path(args.adapter_file)
 
-    # Create value and grad function for loss
+    if args.grad_checkpoint:
+        grad_checkpoint(model.layers[0])
+
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(batch):
+        # Forward and backward pass
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+        # Model update
+        optimizer.update(model, grad)
+
+        return lvalue, toks
+
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
     losses = []
@@ -175,13 +216,8 @@ def train(
             train=True,
         ),
     ):
-        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
-
-        # Model update
-        optimizer.update(model, grad)
-
-        mx.eval(model.parameters(), optimizer.state, lvalue)
+        lvalue, toks = step(batch)
+        mx.eval(state, lvalue, toks)
 
         # Record loss
         losses.append(lvalue.item())
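
The training loop now calls a step function compiled with mx.compile, with the model and optimizer state declared as implicit inputs and outputs, and a single mx.eval(state, ...) forcing the lazy parameter update and loss each iteration. Below is a hedged toy version of the same pattern; the linear model, data, and learning rate are illustrative assumptions, not code from the repo.

# Hedged toy example of a compiled training step; model, data, and lr are illustrative.
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(8, 1)
optimizer = optim.SGD(learning_rate=0.1)


def loss(model, x, y):
    return nn.losses.mse_loss(model(x), y)


loss_value_and_grad = nn.value_and_grad(model, loss)
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    # Forward/backward and the optimizer update are traced into one compiled graph.
    lvalue, grads = loss_value_and_grad(model, x, y)
    optimizer.update(model, grads)
    return lvalue


x, y = mx.random.normal((16, 8)), mx.random.normal((16, 1))
for _ in range(5):
    lvalue = step(x, y)
    mx.eval(state, lvalue)  # materialize the updated parameters and the loss
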
tuner/trainer.py (continued)
@@ -196,12 +232,14 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
+            peak_mem = mx.metal.get_peak_memory() / 2**30
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}"
+                f"Trained Tokens {trained_tokens}, "
+                f"Peak mem {peak_mem:.3f} GB"
             )
 
             if training_callback is not None:
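
The extra log field reads Metal's peak-memory counter. A minimal snippet showing the same conversion (the allocation below is only there so the counter has something to report):

# Read the Metal peak-memory counter in the same units as the new log line.
import mlx.core as mx

a = mx.zeros((1024, 1024))  # throwaway allocation
mx.eval(a)

peak_mem = mx.metal.get_peak_memory() / 2**30  # bytes -> GiB (printed as "GB")
print(f"Peak mem {peak_mem:.3f} GB")
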

tuner/utils.py
@@ -35,9 +35,6 @@ def linear_to_lora_layers(
             lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
         )
 
-    # If the lora_parameters are set, we assume the keys
-    # are correct for the given model
-
     keys = config.get("keys", None)
     if keys is not None:
         keys = set(keys)
@@ -53,7 +50,7 @@ def linear_to_lora_layers(
         ]:
             keys = set(["self_attn.q_proj", "self_attn.v_proj"])
             if model.model_type == "mixtral":
-                keys.add(["block_sparse_moe.gate"])
+                keys.add("block_sparse_moe.gate")
         elif model.model_type == "olmo":
             keys = set(["att_proj"])
         elif model.model_type == "phi-msft":
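
The keys.add change is a genuine bug fix: set.add takes a single hashable element, so passing a list raises TypeError: unhashable type: 'list', and LoRA on Mixtral models would fail at that line. A two-line illustration:

keys = {"self_attn.q_proj", "self_attn.v_proj"}
# keys.add(["block_sparse_moe.gate"])  # TypeError: unhashable type: 'list'
keys.add("block_sparse_moe.gate")  # adds the single key, as intended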