LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model  function

* LoRA: Extract small functions(train_model,evaluate_model)

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at it 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
madroid
2024-06-02 21:38:42 +08:00
committed by GitHub
parent 81318ad4a8
commit c457a3f88b
10 changed files with 232 additions and 206 deletions

View File

@@ -29,9 +29,6 @@ def grad_checkpoint(layer):
@dataclass
class TrainingArgs:
lora_layers: int = field(
default=16, metadata={"help": "Number of layers to fine-tune"}
)
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
@@ -211,6 +208,35 @@ def train(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
lvalue, toks = step(batch)
mx.eval(state, lvalue, toks)
@@ -220,9 +246,9 @@ def train(
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
train_loss = np.mean(losses)
stop = time.perf_counter()
train_loss = np.mean(losses)
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
@@ -253,34 +279,6 @@ def train(
n_tokens = 0
start = time.perf_counter()
# Report validation loss if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
save_adapter(model, args.adapter_file)