LoRA: Improve validation error for LoRA layer count exceeding model layers (#427)

* LoRA: Improve validation error for LoRA layer count exceeding model layers

This commit enhances the error handling when the specified LoRA layer count exceeds the total number of layers in the model. It clarifies the error message to provide actionable feedback for users, guiding them to adjust their input parameters accordingly.
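The commit message describes the intent but the relevant hunk is not shown on this page, so here is a minimal sketch of the kind of validation it describes: fail fast with an actionable message when the requested LoRA layer count exceeds the model's layer count. The function and parameter names are hypothetical, not taken from the actual patch.

```python
def validate_lora_layers(num_lora_layers: int, num_model_layers: int) -> None:
    """Raise a clear, actionable error instead of an obscure indexing failure.

    Hypothetical sketch; names do not come from the actual commit.
    """
    if num_lora_layers > num_model_layers:
        raise ValueError(
            f"Requested {num_lora_layers} LoRA layers, but the model only has "
            f"{num_model_layers} layers. Reduce the LoRA layer count to at "
            f"most {num_model_layers}."
        )


validate_lora_layers(4, 32)  # within bounds, no error
try:
    validate_lora_layers(64, 32)
except ValueError as e:
    print(e)
```

Validating up front like this turns a confusing downstream failure (e.g. an index error while freezing layers) into immediate, user-facing guidance.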

* format + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Madroid Ma on 2024-02-13 22:56:27 +08:00, committed by GitHub
parent d4666615bb
commit 954aa50c54
3 changed files with 17 additions and 3 deletions


@@ -128,8 +128,8 @@ def train(
     loss: callable = default_loss,
 ):
     # Create checkpoints directory if it does not exist
-    if not os.path.exists('checkpoints'):
-        os.makedirs('checkpoints')
+    if not os.path.exists("checkpoints"):
+        os.makedirs("checkpoints")
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
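As an aside, the existence check followed by `os.makedirs` in the hunk above can be collapsed into a single idempotent call, which also avoids a race between the check and the creation. This is a standard-library suggestion, not part of the commit:

```python
import os
import tempfile

# exist_ok=True makes the call a no-op when the directory already exists,
# replacing the separate os.path.exists() check.
base = tempfile.mkdtemp()  # stand-in location for demonstration
checkpoints = os.path.join(base, "checkpoints")

os.makedirs(checkpoints, exist_ok=True)
os.makedirs(checkpoints, exist_ok=True)  # safe to call again, no error
print(os.path.isdir(checkpoints))
```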