LoRA: Improve validation error for LoRA layer count exceeding model layer count (#427)

* LoRA: Improve validation error for LoRA layer count exceeding model layer count

This commit improves error handling when the specified LoRA layer count exceeds the total number of layers in the model. The error message now provides actionable feedback, guiding users to adjust their input parameters accordingly (see the sketch below).

* format + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
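
A minimal sketch of the behavior this change targets, assuming `linear_to_lora_layers` is importable from `mlx_lm.tuner.utils` and that the model's transformer blocks live under `model.model.layers`; the model repo and the resulting layer counts are illustrative assumptions, not part of this commit:

```python
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers  # assumed import path

# Illustrative model; its blocks are exposed as model.model.layers.
model, tokenizer = load("mistralai/Mistral-7B-v0.1")
num_model_layers = len(model.model.layers)  # e.g. 32

try:
    # Requesting more LoRA layers than the model has now fails fast
    # with an actionable message rather than an opaque error.
    linear_to_lora_layers(model, num_model_layers + 8)
except ValueError as e:
    print(e)  # e.g. "Requested 40 LoRA layers but the model only has 32 layers."
```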
Madroid Ma 2024-02-13 22:56:27 +08:00 committed by GitHub
parent d4666615bb
commit 954aa50c54
3 changed files with 17 additions and 3 deletions


@@ -128,8 +128,8 @@ def train(
     loss: callable = default_loss,
 ):
     # Create checkpoints directory if it does not exist
-    if not os.path.exists('checkpoints'):
-        os.makedirs('checkpoints')
+    if not os.path.exists("checkpoints"):
+        os.makedirs("checkpoints")
 
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)


@@ -16,6 +16,14 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         num_lora_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
     """
+
+    def check_lora_layers(num_model_layers):
+        if num_lora_layers > num_model_layers:
+            raise ValueError(
+                f"Requested {num_lora_layers} LoRA layers "
+                f"but the model only has {num_model_layers} layers."
+            )
+
     if model.model_type in [
         "mistral",
         "llama",
@@ -24,6 +32,8 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "stablelm_epoch",
         "qwen2",
     ]:
+        check_lora_layers(len(model.model.layers))
+
         for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
             l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
             l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
@@ -32,11 +42,15 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
                     l.block_sparse_moe.gate
                 )
     elif model.model_type == "olmo":
+        check_lora_layers(len(model.model.transformer.blocks))
+
         for l in model.model.transformer.blocks[
             len(model.model.transformer.blocks) - num_lora_layers :
         ]:
             l.att_proj = LoRALinear.from_linear(l.att_proj)
     elif model.model_type == "phi-msft":
+        check_lora_layers(len(model.transformer.h))
+
         for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
             l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
             l.moe.gate = LoRALinear.from_linear(l.moe.gate)


@@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
     requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
 setup(
     name="mlx-lm",
-    version="0.0.8",
+    version="0.0.10",
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",