custom data with lora

Awni Hannun
2023-12-15 09:56:10 -08:00
parent a3ecda22fe
commit 985f413f99
7 changed files with 1356 additions and 25 deletions


@@ -16,7 +16,6 @@ from mlx.utils import tree_map, tree_flatten, tree_unflatten
 from models import ModelArgs, Model, LoRALinear
-import wikisql


 def build_parser():
@@ -52,6 +51,18 @@ def build_parser():
action="store_true",
help="Do training",
)
parser.add_argument(
"--data",
type=str,
default="data/",
help="Directory with {train, valid, test}.jsonl files",
)
parser.add_argument(
"--lora_layers",
type=int,
default=16,
help="Number of layers to fine-tune",
)
parser.add_argument("--batch_size", type=int, default=4, help="Minibatch size.")
parser.add_argument(
"--iters", type=int, default=1000, help="Iterations to train for."
@@ -59,7 +70,7 @@ def build_parser():
     parser.add_argument(
         "--val_batches",
         type=int,
-        default=100,
+        default=25,
         help="Number of validation batches, -1 uses the entire validation set.",
     )
     parser.add_argument(
@@ -111,8 +122,15 @@ class Tokenizer:
         self._sep = "▁"
         assert self._model.vocab_size() == self._model.get_piece_size()

-    def encode(self, s: str) -> List[int]:
-        return [self._model.bos_id(), *self._model.encode(s)]
+    def encode(self, s: str, eos: bool = False) -> List[int]:
+        toks = [self._model.bos_id(), *self._model.encode(s)]
+        if eos:
+            toks.append(self.eos_id)
+        return toks
+
+    @property
+    def eos_id(self) -> int:
+        return self._model.eos_id()

     def decode(self, t: List[int]) -> str:
         out = self._model.decode(t)
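The only behavioral change to encode is that eos=True appends the sentencepiece EOS id after the encoded tokens, so training examples end with an explicit end-of-sequence marker. A small sketch, assuming a Tokenizer built from a valid sentencepiece model file (the constructor path is an assumption, not shown in this diff):

# Sketch of the updated encode; "tokenizer.model" is a placeholder path.
tok = Tokenizer("tokenizer.model")
with_eos = tok.encode("SELECT 1", eos=True)
without_eos = tok.encode("SELECT 1")
assert with_eos == without_eos + [tok.eos_id]  # eos=True only appends the EOS id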
@@ -125,6 +143,44 @@ class Tokenizer:
         return self._model.vocab_size()


+class Dataset:
+    """
+    Light-weight wrapper to hold lines from a jsonl file
+    """
+
+    def __init__(self, path: Path, key: str = "text"):
+        if not path.exists():
+            self._data = None
+        else:
+            with open(path, "r") as fid:
+                self._data = [json.loads(l) for l in fid]
+        self._key = key
+
+    def __getitem__(self, idx: int):
+        return self._data[idx][self._key]
+
+    def __len__(self):
+        return len(self._data)
+
+
+def load(args):
+    names = ("train", "valid", "test")
+    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
+    if args.train and len(train) == 0:
+        raise ValueError(
+            "Training set not found or empty. Must provide training set for fine-tuning."
+        )
+    if args.train and len(valid) == 0:
+        raise ValueError(
+            "Validation set not found or empty. Must provide validation set for fine-tuning."
+        )
+    if args.test and len(test) == 0:
+        raise ValueError(
+            "Test set not found or empty. Must provide test set for evaluation."
+        )
+    return train, valid, test
+
+
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
     logits, _ = model(inputs)
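A short usage sketch of the added Dataset and load helpers; the SimpleNamespace stands in for the parsed CLI arguments and the paths assume the default data/ layout:

from types import SimpleNamespace

# load() builds one Dataset per split from {args.data}/{train,valid,test}.jsonl
# and raises if a split required by --train/--test is missing or empty.
args = SimpleNamespace(data="data/", train=True, test=False)
train_set, valid_set, test_set = load(args)
print(len(train_set))     # number of training examples
print(train_set[0][:80])  # indexing a Dataset returns the raw "text" string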
@@ -139,24 +195,31 @@ def loss(model, inputs, targets, lengths):
     return ce, ntoks


-def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
+def iterate_batches(dset, tokenizer, batch_size, train=False):
     # Shuffle indices
-    indices = np.arange(len(dset))
-    if shuffle:
-        indices = np.random.permutation(indices)
+    while True:
+        indices = np.arange(len(dset))
+        if train:
+            indices = np.random.permutation(indices)

-    # Collect batches from dataset
-    for i in range(0, len(indices) - batch_size + 1, batch_size):
-        # Encode batch
-        batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
-        lengths = [len(x) for x in batch]
+        # Collect batches from dataset
+        for i in range(0, len(indices) - batch_size + 1, batch_size):
+            # Encode batch
+            batch = [
+                tokenizer.encode(dset[indices[i + j]], eos=True)
+                for j in range(batch_size)
+            ]
+            lengths = [len(x) for x in batch]

-        # Pad to the max length
-        batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
-        for j in range(batch_size):
-            batch_arr[j, : lengths[j]] = batch[j]
-        batch = mx.array(batch_arr)
-        yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            # Pad to the max length
+            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+            for j in range(batch_size):
+                batch_arr[j, : lengths[j]] = batch[j]
+            batch = mx.array(batch_arr)
+            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+        if not train:
+            break


 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
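With train=True the rewritten iterate_batches reshuffles and loops over the data indefinitely, while the default single unshuffled pass is what evaluate relies on; either way, every example is now encoded with an EOS token. A sketch of how callers bound the infinite generator, mirroring the loops in evaluate and train below:

# Sketch: zip() with a range bounds the number of steps drawn from the endless
# training generator; the evaluation generator stops after one epoch.
train_stream = iterate_batches(train_set, tokenizer, batch_size=4, train=True)
for step, (inputs, targets, lengths) in zip(range(10), train_stream):
    pass  # a real caller computes the loss here, as train() does

eval_batches = list(iterate_batches(valid_set, tokenizer, batch_size=4))  # one pass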
@@ -164,7 +227,7 @@ def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     ntokens = 0
     for it, batch in zip(
         range(num_batches),
-        iterate_batches(dataset, tokenizer, batch_size, shuffle=False),
+        iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
         all_losses.append((losses * toks).item())
@@ -183,7 +246,8 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
-        range(args.iters), iterate_batches(train_set, tokenizer, args.batch_size)
+        range(args.iters),
+        iterate_batches(train_set, tokenizer, args.batch_size, train=True),
     ):
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
@@ -289,7 +353,7 @@ if __name__ == "__main__":
     # Freeze all layers other than LORA linears
     model.freeze()
-    for l in model.layers[16:32]:
+    for l in model.layers[-args.lora_layers :]:
         l.attention.wq = LoRALinear.from_linear(l.attention.wq)
         l.attention.wv = LoRALinear.from_linear(l.attention.wv)
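Replacing the hard-coded model.layers[16:32] with a slice of the last args.lora_layers layers keeps the default behavior on a 32-layer model (lora_layers=16) while letting smaller or larger values adapt fewer or more layers. A toy check of the slicing:

# Negative slicing selects the last N layers, so the default matches the
# previous hard-coded range on a 32-layer model.
layers = list(range(32))  # stand-in for model.layers
lora_layers = 16
assert layers[-lora_layers:] == layers[16:32]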
@@ -299,7 +363,7 @@ if __name__ == "__main__":
print(f"Trainable parameters {p:.3f}M")
print("Loading datasets")
train_set, valid_set, test_set = wikisql.load()
train_set, valid_set, test_set = load(args)
# Resume training the given adapters.
if args.resume_adapter_file is not None: