Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm
* change lora loading
* fix olmo lora
* remove a bunch of unused stuff from plamo
* move phixtral to mlx-lm and out of llms/
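The first bullet, lazy model import, refers to resolving a checkpoint's architecture module only when its config names it, rather than importing every supported model family up front. Below is a minimal sketch of that pattern using importlib; the `mlx_lm.models.<model_type>` layout and the `Model`/`ModelArgs` names are assumptions for illustration, not a quote of the repository's loader.

```python
import importlib
from typing import Any, Tuple


def lazy_model_classes(model_type: str) -> Tuple[Any, Any]:
    """Import the module for a single architecture on demand.

    Only the requested family (e.g. "llama", "olmo", "plamo") is imported,
    so start-up no longer pays for, or fails on, architectures that are
    never used. The module path and class names are assumed for this sketch.
    """
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError as e:
        raise ValueError(f"Model type {model_type!r} is not supported.") from e
    return arch.Model, arch.ModelArgs
```

A loader can then call this with the `model_type` field from a checkpoint's config and instantiate the returned classes, which is what makes the import lazy.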
@@ -9,7 +9,8 @@ from mlx.utils import tree_flatten
 
 from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
-from .utils import LORA_SUPPORTED_MODELS, generate, load
+from .tuner.utils import linear_to_lora_layers
+from .utils import generate, load
 
 
 def build_parser():
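For context on the LoRALinear import that appears above: a LoRA layer leaves the original projection weights untouched and learns a low-rank correction, computing roughly y = W x + scale * (x A) B for a small rank r. The class below is a schematic of that idea in MLX, not the repository's tuner.lora implementation; the rank and scale defaults are illustrative.

```python
import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinearSketch(nn.Module):
    """Schematic LoRA wrapper: a frozen base linear plus a trainable low-rank update."""

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 2.0):
        # Infer dimensions from the existing layer and reuse its weights.
        output_dims, input_dims = linear.weight.shape
        lora = LoRALinearSketch(input_dims, output_dims, rank, scale)
        lora.linear = linear  # base weights are untouched; the script freezes them via model.freeze()
        return lora

    def __init__(self, input_dims: int, output_dims: int, rank: int = 8, scale: float = 2.0):
        super().__init__()
        self.scale = scale
        self.linear = nn.Linear(input_dims, output_dims, bias=False)
        # A is small random, B starts at zero, so the adapted layer initially
        # reproduces the base layer exactly.
        bound = 1.0 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(input_dims, rank))
        self.lora_b = mx.zeros((rank, output_dims))

    def __call__(self, x):
        y = self.linear(x)
        z = (x @ self.lora_a) @ self.lora_b
        return y + self.scale * z
```

Training then only updates lora_a and lora_b, which is why the script below can freeze the whole model and rely on the conversion step to leave just the adapter parameters trainable.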
@@ -169,19 +170,10 @@ if __name__ == "__main__":
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
 
-    if model.__class__ not in LORA_SUPPORTED_MODELS:
-        raise ValueError(
-            f"Model {model.__class__} not supported. "
-            f"Supported models: {LORA_SUPPORTED_MODELS}"
-        )
-
-    # Freeze all layers other than LORA linears
+    # Freeze all layers
     model.freeze()
-    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
-        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
-        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
-        if hasattr(l, "block_sparse_moe"):
-            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
+    # Convert linear layers to lora layers and unfreeze in the process
+    linear_to_lora_layers(model, args.lora_layers)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
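The deleted loop above is essentially what the new linear_to_lora_layers helper centralizes: take the last args.lora_layers transformer blocks, swap their attention q/v projections (and the MoE router gate, when present) for LoRA wrappers, and leave only the new adapter parameters trainable. The sketch below is reconstructed from that deleted code; the real helper lives in mlx_lm.tuner.utils and may differ in details.

```python
from mlx_lm.tuner.lora import LoRALinear


def linear_to_lora_layers(model, num_lora_layers: int):
    """Sketch of the refactored helper, mirroring the loop this commit removed."""
    # Only the last `num_lora_layers` transformer blocks get adapters.
    for layer in model.model.layers[len(model.model.layers) - num_lora_layers :]:
        layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)
        # Mixtral-style MoE blocks additionally adapt the router gate.
        if hasattr(layer, "block_sparse_moe"):
            layer.block_sparse_moe.gate = LoRALinear.from_linear(
                layer.block_sparse_moe.gate
            )
```

As in the new hunk, the call site is simply model.freeze() followed by linear_to_lora_layers(model, args.lora_layers).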