# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import math
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import utils as lora_utils
from mlx.utils import tree_flatten
from models import LoRALinear


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Generation args
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="The maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--add-eos-token",
        type=int,
        default=1,
        help="Enable add_eos_token for tokenizer",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="data/",
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=16,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument(
        "--val-batches",
        type=int,
        default=25,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument(
        "--steps-per-report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        default=None,
        help="Load path to resume training with the given adapter weights.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=100,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser
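

# Example usage (assuming this file is saved as lora.py; the paths and values
# below are placeholders, not requirements):
#
#   python lora.py --model mlx_model --train --data data/ --iters 1000
#   python lora.py --model mlx_model --test
#   python lora.py --model mlx_model --prompt "..." --max-tokens 50
#
# The --data directory should contain train.jsonl, valid.jsonl, and test.jsonl,
# one JSON object per line with the sample text under the "text" key.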
					
						
class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        # Treat a missing file as an empty dataset so callers see the clear
        # "not found or empty" error instead of a TypeError on len(None).
        if self._data is None:
            return 0
        return len(self._data)


def load(args):
    def load_and_check(name):
        dataset_path = Path(args.data) / f"{name}.jsonl"
        try:
            return Dataset(dataset_path)
        except Exception as e:
            print(f"Unable to build dataset {dataset_path} ({e})")
            raise

    names = ("train", "valid", "test")
    train, valid, test = (load_and_check(n) for n in names)

    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test


def loss(model, inputs, targets, lengths):
    # Run model on inputs
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask padding tokens
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    # Calculate the loss
    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks
    return ce, ntoks


def iterate_batches(dset, tokenizer, batch_size, train=False):
    # Shuffle indices
    while True:
        indices = np.arange(len(dset))
        if train:
            indices = np.random.permutation(indices)

        # Collect batches from dataset
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            # Encode batch
            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
            lengths = [len(x) for x in batch]

            # Check if any sequence is longer than 2048 tokens
            if max(lengths) > 2048:
                print(
                    "[WARNING] Some sequences are longer than 2048 tokens. "
                    "Consider pre-splitting your data to save memory."
                )

            # Pad to the max length
            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)

            for j in range(batch_size):
                batch_arr[j, : lengths[j]] = batch[j]
            batch = mx.array(batch_arr)
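            # The yield below shifts by one token for next-token prediction:
            # inputs drop the last token, targets drop the first.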
            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)

        if not train:
            break


def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
    all_losses = []
    ntokens = 0

    # num_batches can be -1 to indicate the entire set
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
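    # iter(int, 1) is an unbounded iterator: int() always returns 0 and never hits
    # the sentinel 1, so the zip below is limited only by the batch generator.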

    for it, batch in zip(
        index_iterator,
        iterate_batches(dataset, tokenizer, batch_size),
    ):
        losses, toks = loss(model, *batch)
        all_losses.append((losses * toks).item())
        ntokens += toks.item()

    return np.sum(all_losses) / ntokens


def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
    # Create value and grad function for loss
    loss_value_and_grad = nn.value_and_grad(model, loss)

    losses = []
    n_tokens = 0

    # Main training loop
    start = time.perf_counter()
    for it, batch in zip(
        range(args.iters),
        iterate_batches(train_set, tokenizer, args.batch_size, train=True),
    ):
        # Forward and backward pass
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

        # Model update
        optimizer.update(model, grad)
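        # MLX evaluates lazily; force the parameter update, optimizer state, and
        # loss value to be computed before moving on.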
        mx.eval(model.parameters(), optimizer.state, lvalue)

        # Record loss
        losses.append(lvalue.item())
        n_tokens += toks.item()

        # Report training loss if needed
        if (it + 1) % args.steps_per_report == 0:
            train_loss = np.mean(losses)

            stop = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
            )
            losses = []
            n_tokens = 0
            start = time.perf_counter()

        # Report validation loss if needed
        if it == 0 or (it + 1) % args.steps_per_eval == 0:
            stop = time.perf_counter()
            val_loss = evaluate(
                model, val_set, loss, tokenizer, args.batch_size, args.val_batches
            )
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val took {(time.perf_counter() - stop):.3f}s"
            )

            start = time.perf_counter()

        # Save adapter weights if needed
        if (it + 1) % args.save_every == 0:
            mx.savez(
                args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
            )
            print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")


def generate(model, prompt, tokenizer, args):
    print(prompt, end="", flush=True)

    prompt = mx.array(tokenizer.encode(prompt))

    tokens = []
    skip = 0
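    # Stream the decoded text as tokens arrive, holding back the last character
    # since it can change once the next token is decoded.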
    for token, n in zip(
        lora_utils.generate(prompt, model, args.temp),
        range(args.max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break

        tokens.append(token.item())
        s = tokenizer.decode(tokens)
        if len(s) - skip > 1:
            print(s[skip:-1], end="", flush=True)
            skip = len(s) - 1
    print(tokenizer.decode(tokens)[skip:], flush=True)
    print("=" * 10)
    if len(tokens) == 0:
        print("No tokens generated for this prompt")
        return


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    # Building tokenizer_config
    tokenizer_config = {}
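    # add_eos_token is only configured for training so each example ends with an
    # explicit EOS and the model learns where completions stop.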
    if args.train:
        tokenizer_config["add_eos_token"] = bool(args.add_eos_token)

    print("Loading pretrained model")
    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
    # Freeze all layers other than LORA linears
    model.freeze()
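    # Replace the attention query/value projections (and the MoE gate, if present)
    # in the last `lora_layers` transformer blocks with LoRA layers; only their
    # low-rank adapters remain trainable.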
    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
        if hasattr(l, "block_sparse_moe"):
            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    print(f"Total parameters {p:.3f}M")
    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    print(f"Trainable parameters {p:.3f}M")

    print("Loading datasets")
    train_set, valid_set, test_set = load(args)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    if args.train:
        print("Training")
        opt = optim.Adam(learning_rate=args.learning_rate)

        # Train model
        train(model, train_set, valid_set, opt, loss, tokenizer, args)

        # Save adapter weights
        mx.savez(args.adapter_file, **dict(tree_flatten(model.trainable_parameters())))

    # Load the LoRA adapter weights which we assume should exist by this point
    if not Path(args.adapter_file).is_file():
        raise ValueError(
            f"Adapter file {args.adapter_file} missing. "
            "Use --train to learn and save the adapters.npz."
        )
    model.load_weights(args.adapter_file, strict=False)

    if args.test:
        print("Testing")
        model.eval()
        test_loss = evaluate(
            model,
            test_set,
            loss,
            tokenizer,
            args.batch_size,
            num_batches=args.test_batches,
        )
        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

    if args.prompt is not None:
        print("Generating")
        generate(model, args.prompt, tokenizer, args)