Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Further update to trainer/utils for correct layer selection. Successful training
@@ -53,9 +53,13 @@ def linear_to_lora_layers(
         Default: ``False``
     """
     if hasattr(model, "backbone"):
-        num_layers = len(model.backbone.layers)
+        layers = model.backbone.layers
+    elif hasattr(model, "layers"):
+        layers = model.layers
     else:
-        num_layers = len(model.layers)
+        raise ValueError("Unsupported model structure")
+
+    num_layers = len(layers)
 
     if num_lora_layers < 0:
         num_lora_layers = num_layers
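For readers skimming the hunk above: Mamba-style models keep their blocks under model.backbone.layers, while transformer-style models expose model.layers directly, so both shapes now resolve to a single layers list and anything else fails fast. The sketch below is illustrative only and is not part of the commit; MambaLikeModel and TransformerLikeModel are hypothetical stand-ins, not the real mlx-lm model classes.

# Illustrative sketch only (not from the commit): shows which branch each
# model shape takes in the selection logic above. The two model classes are
# hypothetical stand-ins, not real mlx-lm modules.

def select_layers(model):
    if hasattr(model, "backbone"):      # Mamba-style: blocks live on the backbone
        return model.backbone.layers
    elif hasattr(model, "layers"):      # transformer-style: blocks at the top level
        return model.layers
    raise ValueError("Unsupported model structure")


class _Backbone:
    def __init__(self):
        self.layers = ["block0", "block1"]


class MambaLikeModel:                   # hypothetical stand-in
    def __init__(self):
        self.backbone = _Backbone()


class TransformerLikeModel:             # hypothetical stand-in
    def __init__(self):
        self.layers = ["block0", "block1", "block2"]


print(len(select_layers(MambaLikeModel())))        # 2, via model.backbone.layers
print(len(select_layers(TransformerLikeModel())))  # 3, via model.layers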
@@ -143,9 +147,18 @@ def linear_to_lora_layers(
                 "self_attn.kv_b_proj",
             ]
         )
+    elif model.model_type == "mamba":
+        keys = set([
+            "mixer.in_proj",
+            "mixer.x_proj",
+            "mixer.dt_proj",
+            "mixer.out_proj",
+        ])
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[num_layers - num_lora_layers :]:
+    # Modified the layer selection to handle both regular and backbone structures:
+    layers = model.backbone.layers if hasattr(model, "backbone") else model.layers
+    for l in layers[num_layers - num_lora_layers :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
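To see what the final loop does with the new mamba key set, here is a self-contained sketch, again not from the repo: it slices off the last num_lora_layers blocks and keeps only the submodules whose relative path matches a key, skipping everything else (e.g. a conv1d). Linear, Mixer and Block are hypothetical stand-ins for the real mlx.nn modules, and the real code wraps each match with to_lora() rather than just collecting names.

# Illustrative sketch (not from the repo): how the mixer.* key set selects
# LoRA targets in the last `num_lora_layers` blocks. Stand-in classes replace
# the real mlx.nn modules; the real loop wraps each match with to_lora().

keys = {"mixer.in_proj", "mixer.x_proj", "mixer.dt_proj", "mixer.out_proj"}


class Linear:                   # hypothetical stand-in for a linear projection
    pass


class Mixer:                    # hypothetical stand-in for a Mamba block's mixer
    def __init__(self):
        self.in_proj = Linear()
        self.x_proj = Linear()
        self.dt_proj = Linear()
        self.out_proj = Linear()
        self.conv1d = object()  # present in the block but not a LoRA target


class Block:
    def __init__(self):
        self.mixer = Mixer()

    def named_modules(self):
        # Flat stand-in for named_modules(): yields ("mixer.<name>", module) pairs.
        for name, module in vars(self.mixer).items():
            yield f"mixer.{name}", module


layers = [Block() for _ in range(4)]
num_layers, num_lora_layers = len(layers), 2

for l in layers[num_layers - num_lora_layers:]:
    targets = [k for k, m in l.named_modules() if k in keys]
    print(targets)
    # ['mixer.in_proj', 'mixer.x_proj', 'mixer.dt_proj', 'mixer.out_proj']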