Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)
Some improvements to LoRA (#528)
* set cache_limit
* remove set cache_limit
* cleanup
* add gradient checkpointing
* fix sort
* monkey patch call for checkpoint
* fix example config
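The "add gradient checkpointing" and "monkey patch call for checkpoint" items go together: instead of storing every intermediate activation, the trainer can recompute a layer's forward pass during backprop, trading compute for memory. The diff excerpt below only covers the LoRA key selection, so as a rough sketch of the checkpointing side, here is one way such a patch could look with mlx.core.checkpoint; the helper name grad_checkpoint and the exact patching strategy are illustrative assumptions, not necessarily what this commit ships.

# A minimal, hypothetical sketch of gradient checkpointing via a monkey patch,
# assuming mlx.core.checkpoint (recompute-on-backward) and an mlx.nn module.
# The helper name and structure are illustrative, not taken from the commit.
import mlx.core as mx
import mlx.nn as nn


def grad_checkpoint(layer: nn.Module):
    # Grab the original class-level forward so every instance of this layer
    # type gets patched, not just the one passed in.
    fn = type(layer).__call__

    def checkpointed_call(module, *args, **kwargs):
        def inner(params, *args, **kwargs):
            # Re-bind the trainable parameters so the recomputed forward pass
            # is traced with respect to them.
            module.update(params)
            return fn(module, *args, **kwargs)

        # mx.checkpoint re-runs `inner` during backprop instead of keeping
        # its intermediate activations in memory.
        return mx.checkpoint(inner)(module.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_call

Because the patch replaces the class-level __call__, applying it once to a single transformer block checkpoints every block of the same type.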
@@ -35,9 +35,6 @@ def linear_to_lora_layers(
            lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
        )

    # If the lora_parameters are set, we assume the keys
    # are correct for the given model

    keys = config.get("keys", None)
    if keys is not None:
        keys = set(keys)
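The hunk above pulls the LoRA hyperparameters, and optionally an explicit list of module keys, out of the config dict. A hypothetical config exercising exactly the fields read here (rank, alpha, scale, and the optional keys); the real schema used by the training scripts may carry more entries:

# Hypothetical LoRA config covering only the fields read in the hunk above
# (rank / alpha / scale / optional keys); the actual schema may differ.
config = {
    "rank": 8,       # passed to LoRALinear.from_linear as r
    "alpha": 16,     # passed to LoRALinear.from_linear as alpha
    "scale": 10.0,   # passed to LoRALinear.from_linear as scale
    # When "keys" is present, these module paths are converted to LoRA layers
    # as given instead of falling back to the per-model defaults below.
    "keys": ["self_attn.q_proj", "self_attn.v_proj"],
}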
@@ -53,7 +50,7 @@ def linear_to_lora_layers(
    ]:
        keys = set(["self_attn.q_proj", "self_attn.v_proj"])
        if model.model_type == "mixtral":
-           keys.add(["block_sparse_moe.gate"])
+           keys.add("block_sparse_moe.gate")
    elif model.model_type == "olmo":
        keys = set(["att_proj"])
    elif model.model_type == "phi-msft":
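The one-line change in the second hunk fixes a latent crash rather than a style issue: set.add takes a single hashable element, and a Python list is not hashable, so the removed line would raise as soon as the mixtral branch ran. A standalone illustration (not code from the repo):

# Standalone illustration of the bug fixed in the second hunk (not repo code).
keys = {"self_attn.q_proj", "self_attn.v_proj"}

# Old line: set.add expects a hashable element, and a list is unhashable.
try:
    keys.add(["block_sparse_moe.gate"])
except TypeError as err:
    print(err)  # unhashable type: 'list'

# Fixed line: add the string itself so the MoE gate is actually included
# in the set of modules to convert to LoRA layers.
keys.add("block_sparse_moe.gate")
print(sorted(keys))
# ['block_sparse_moe.gate', 'self_attn.q_proj', 'self_attn.v_proj']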