Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 21:01:32 +08:00)
Commit: custom data with lora
Changed file: lora/lora.py (110 changed lines)
@@ -16,7 +16,6 @@ from mlx.utils import tree_map, tree_flatten, tree_unflatten
 
 
 from models import ModelArgs, Model, LoRALinear
-import wikisql
 
 
 def build_parser():
@@ -52,6 +51,18 @@ def build_parser():
         action="store_true",
         help="Do training",
     )
+    parser.add_argument(
+        "--data",
+        type=str,
+        default="data/",
+        help="Directory with {train, valid, test}.jsonl files",
+    )
+    parser.add_argument(
+        "--lora_layers",
+        type=int,
+        default=16,
+        help="Number of layers to fine-tune",
+    )
     parser.add_argument("--batch_size", type=int, default=4, help="Minibatch size.")
     parser.add_argument(
         "--iters", type=int, default=1000, help="Iterations to train for."
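Note: the new --data flag expects a directory containing train.jsonl, valid.jsonl, and test.jsonl, matching the Dataset/load helpers added further down. A rough sketch of the layout and an invocation (the example text and the particular flags shown are illustrative, not taken from this commit):

    data/
        train.jsonl   # one JSON object per line, e.g. {"text": "Q: 61 + 3?\nA: 64"}
        valid.jsonl
        test.jsonl

    # Hypothetical run; any other required arguments (e.g. the model path) are omitted here.
    python lora.py --train --data data/ --lora_layers 16 --batch_size 4 --iters 1000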
@@ -59,7 +70,7 @@ def build_parser():
     parser.add_argument(
         "--val_batches",
         type=int,
-        default=100,
+        default=25,
         help="Number of validation batches, -1 uses the entire validation set.",
     )
     parser.add_argument(
@@ -111,8 +122,15 @@ class Tokenizer:
         self._sep = "▁"
         assert self._model.vocab_size() == self._model.get_piece_size()
 
-    def encode(self, s: str) -> List[int]:
-        return [self._model.bos_id(), *self._model.encode(s)]
+    def encode(self, s: str, eos: bool = False) -> List[int]:
+        toks = [self._model.bos_id(), *self._model.encode(s)]
+        if eos:
+            toks.append(self.eos_id)
+        return toks
+
+    @property
+    def eos_id(self) -> int:
+        return self._model.eos_id()
 
     def decode(self, t: List[int]) -> str:
         out = self._model.decode(t)
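Note: the effect of the new eos flag is simply to terminate a sequence with the tokenizer's EOS id; roughly (the construction below is assumed for illustration, not shown in this hunk):

    tok = Tokenizer("tokenizer.model")   # hypothetical model path
    tok.encode("hello")                  # [bos_id, ...piece ids...]
    tok.encode("hello", eos=True)        # same ids with eos_id appended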
@@ -125,6 +143,44 @@ class Tokenizer:
         return self._model.vocab_size()
 
 
+class Dataset:
+    """
+    Light-weight wrapper to hold lines from a jsonl file
+    """
+
+    def __init__(self, path: Path, key: str = "text"):
+        if not path.exists():
+            self._data = None
+        else:
+            with open(path, "r") as fid:
+                self._data = [json.loads(l) for l in fid]
+        self._key = key
+
+    def __getitem__(self, idx: int):
+        return self._data[idx][self._key]
+
+    def __len__(self):
+        return len(self._data)
+
+
+def load(args):
+    names = ("train", "valid", "test")
+    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
+    if args.train and len(train) == 0:
+        raise ValueError(
+            "Training set not found or empty. Must provide training set for fine-tuning."
+        )
+    if args.train and len(valid) == 0:
+        raise ValueError(
+            "Validation set not found or empty. Must provide validation set for fine-tuning."
+        )
+    if args.test and len(test) == 0:
+        raise ValueError(
+            "Test set not found or empty. Must provide test set for evaluation."
+        )
+    return train, valid, test
+
+
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
     logits, _ = model(inputs)
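Note: a minimal sketch of how the new Dataset wrapper behaves on a hand-written jsonl file (file name and contents are illustrative):

    import json
    from pathlib import Path

    Path("data").mkdir(exist_ok=True)
    # One JSON object per line; the default key is "text".
    with open("data/train.jsonl", "w") as fid:
        for t in ["Q: 1 + 1?\nA: 2", "Q: 2 + 2?\nA: 4"]:
            fid.write(json.dumps({"text": t}) + "\n")

    dset = Dataset(Path("data/train.jsonl"))
    print(len(dset))   # 2
    print(dset[0])     # first example's text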
@@ -139,24 +195,31 @@ def loss(model, inputs, targets, lengths):
     return ce, ntoks
 
 
-def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
+def iterate_batches(dset, tokenizer, batch_size, train=False):
     # Shuffle indices
-    indices = np.arange(len(dset))
-    if shuffle:
-        indices = np.random.permutation(indices)
+    while True:
+        indices = np.arange(len(dset))
+        if train:
+            indices = np.random.permutation(indices)
 
-    # Collect batches from dataset
-    for i in range(0, len(indices) - batch_size + 1, batch_size):
-        # Encode batch
-        batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
-        lengths = [len(x) for x in batch]
+        # Collect batches from dataset
+        for i in range(0, len(indices) - batch_size + 1, batch_size):
+            # Encode batch
+            batch = [
+                tokenizer.encode(dset[indices[i + j]], eos=True)
+                for j in range(batch_size)
+            ]
+            lengths = [len(x) for x in batch]
 
-        # Pad to the max length
-        batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
-        for j in range(batch_size):
-            batch_arr[j, : lengths[j]] = batch[j]
-        batch = mx.array(batch_arr)
-        yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            # Pad to the max length
+            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+            for j in range(batch_size):
+                batch_arr[j, : lengths[j]] = batch[j]
+            batch = mx.array(batch_arr)
+            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+        if not train:
+            break
 
 
 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
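Note: with the epoch loop wrapped in while True, train=True yields a reshuffled, effectively endless stream of batches (the training loop bounds it by zipping with range(args.iters)), while the default train=False makes a single in-order pass and then stops, which is what evaluate() relies on. A hedged usage sketch, assuming train_set, valid_set, and tokenizer exist as elsewhere in the file:

    # Training: endless shuffled stream, bounded here by zip with range().
    for it, batch in zip(range(8), iterate_batches(train_set, tokenizer, 4, train=True)):
        pass

    # Evaluation: one deterministic pass; the generator exhausts on its own.
    for batch in iterate_batches(valid_set, tokenizer, 4):
        pass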
@@ -164,7 +227,7 @@ def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     ntokens = 0
     for it, batch in zip(
         range(num_batches),
-        iterate_batches(dataset, tokenizer, batch_size, shuffle=False),
+        iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
         all_losses.append((losses * toks).item())
@@ -183,7 +246,8 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters), iterate_batches(train_set, tokenizer, args.batch_size)
+        range(args.iters),
+        iterate_batches(train_set, tokenizer, args.batch_size, train=True),
     ):
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
@@ -289,7 +353,7 @@ if __name__ == "__main__":
 
     # Freeze all layers other than LORA linears
     model.freeze()
-    for l in model.layers[16:32]:
+    for l in model.layers[-args.lora_layers :]:
         l.attention.wq = LoRALinear.from_linear(l.attention.wq)
         l.attention.wv = LoRALinear.from_linear(l.attention.wv)
 
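Note: the negative slice makes the previously hard-coded range configurable: model.layers[-args.lora_layers:] adapts the last N transformer blocks, so on a 32-layer model the default --lora_layers 16 reproduces the old layers[16:32]. Plain-Python illustration:

    layers = list(range(32))
    assert layers[-16:] == layers[16:32]   # default matches the old behaviour
    layers[-8:]                            # --lora_layers 8 would touch only the top 8 blocks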
@@ -299,7 +363,7 @@ if __name__ == "__main__":
     print(f"Trainable parameters {p:.3f}M")
 
     print("Loading datasets")
-    train_set, valid_set, test_set = wikisql.load()
+    train_set, valid_set, test_set = load(args)
 
     # Resume training the given adapters.
     if args.resume_adapter_file is not None:
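Note: with wikisql.load() gone, reusing the WikiSQL data (or any other corpus) now means exporting it to the {train, valid, test}.jsonl layout first. A generic helper along these lines would do; it is a sketch, not part of the commit:

    import json

    def export_jsonl(texts, path):
        # texts: any iterable of training strings
        with open(path, "w") as fid:
            for t in texts:
                fid.write(json.dumps({"text": t}) + "\n")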