mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
LoRA: Improve validation error for LoRA layer count exceeding model layer (#427)
* LoRA: Improve validation error for LoRA layer count exceeding model layer

This commit enhances the error handling used when the specified LoRA layer count exceeds the total number of layers in the model. It clarifies the error message so that users get actionable feedback guiding them to adjust their input parameters accordingly.

* format + nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent d4666615bb
commit 954aa50c54
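For context, the following is a minimal sketch of how the new validation surfaces to a caller. The import path and model identifier are assumptions for illustration, not taken from this commit:

    from mlx_lm import load
    from mlx_lm.tuner.utils import linear_to_lora_layers  # import path assumed

    model, tokenizer = load("mistralai/Mistral-7B-v0.1")  # hypothetical model id

    try:
        # Request far more LoRA layers than the model has transformer blocks.
        linear_to_lora_layers(model, num_lora_layers=1000)
    except ValueError as e:
        print(e)  # e.g. "Requested 1000 LoRA layers but the model only has 32 layers."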
@@ -128,8 +128,8 @@ def train(
     loss: callable = default_loss,
 ):
     # Create checkpoints directory if it does not exist
-    if not os.path.exists('checkpoints'):
-        os.makedirs('checkpoints')
+    if not os.path.exists("checkpoints"):
+        os.makedirs("checkpoints")
 
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
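As an aside, the existence check above is the standard "create if missing" idiom; os.makedirs also accepts exist_ok, which avoids a race if the directory appears between the check and the call. A one-line sketch, not part of this commit:

    import os

    # Equivalent to the check-then-create pair, and safe under concurrency.
    os.makedirs("checkpoints", exist_ok=True)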
@@ -16,6 +16,14 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         num_lora_layers (int): The number of blocks to convert to lora layers
         starting from the last layer.
     """
+
+    def check_lora_layers(num_model):
+        if num_lora_layers > num_model:
+            raise ValueError(
+                f"Requested {num_lora_layers} LoRA layers "
+                f"but the model only has {num_model} layers."
+            )
+
     if model.model_type in [
         "mistral",
         "llama",
@@ -24,6 +32,8 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "stablelm_epoch",
         "qwen2",
     ]:
+        check_lora_layers(len(model.model.layers))
+
         for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
             l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
             l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
@@ -32,11 +42,15 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
             l.block_sparse_moe.gate
         )
     elif model.model_type == "olmo":
+        check_lora_layers(len(model.model.transformer.blocks))
+
         for l in model.model.transformer.blocks[
             len(model.model.transformer.blocks) - num_lora_layers :
         ]:
             l.att_proj = LoRALinear.from_linear(l.att_proj)
     elif model.model_type == "phi-msft":
+        check_lora_layers(len(model.transformer.h))
+
         for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
             l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
             l.moe.gate = LoRALinear.from_linear(l.moe.gate)
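The hunks above repeat one pattern per architecture: find where that model type keeps its transformer blocks, validate the requested count against it, then wrap the trailing blocks. A condensed sketch of that lookup (a hypothetical helper, not part of the diff):

    def model_blocks(model):
        # Container of transformer blocks for each supported model_type.
        if model.model_type == "olmo":
            return model.model.transformer.blocks
        if model.model_type == "phi-msft":
            return model.transformer.h
        # mistral, llama, stablelm_epoch, qwen2, and similar architectures
        return model.model.layers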
@@ -8,7 +8,7 @@ with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid:
     requirements = [str(r) for r in pkg_resources.parse_requirements(fid)]
 setup(
     name="mlx-lm",
-    version="0.0.8",
+    version="0.0.10",
     description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",