Some improvements to LoRA (#528)

* set cache_limit

* remove set cache_limit

* cleanup

* add gradient checkpointing

* fix sort

* monkey patch call for checkpoint (see the sketch after this list)

* fix example config
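
The "add gradient checkpointing" and "monkey patch call for checkpoint" bullets refer to recomputing each transformer block's activations during the backward pass instead of storing them, trading extra compute for a much smaller memory footprint while fine-tuning. Below is a minimal sketch of how such a monkey patch could be built on `mlx.core.checkpoint`; the `grad_checkpoint` helper name and the exact structure are illustrative, not necessarily what this commit adds.

import mlx.core as mx
import mlx.nn as nn


def grad_checkpoint(layer: nn.Module):
    """Patch all instances of type(layer) so the forward pass is
    recomputed in the backward pass instead of caching activations."""
    fn = type(layer).__call__

    def checkpointed_fn(module, *args, **kwargs):
        # mx.checkpoint needs the trainable parameters passed explicitly
        # so it can rebuild the subgraph when gradients are requested.
        def inner_fn(params, *args, **kwargs):
            module.update(params)
            return fn(module, *args, **kwargs)

        return mx.checkpoint(inner_fn)(
            module.trainable_parameters(), *args, **kwargs
        )

    # Monkey patch the class-level __call__ so every layer of this type
    # (e.g. every transformer block) is checkpointed.
    type(layer).__call__ = checkpointed_fn

Calling something like `grad_checkpoint(model.layers[0])` once before training would then checkpoint every block of that class.
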
Awni Hannun authored on 2024-03-12 20:02:03 -07:00 (committed by GitHub)
commit 39084e81c2, parent e56d9015ef
4 changed files with 68 additions and 25 deletions

@@ -35,9 +35,6 @@ def linear_to_lora_layers(
             lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
         )
-    # If the lora_parameters are set, we assume the keys
-    # are correct for the given model
     keys = config.get("keys", None)
     if keys is not None:
         keys = set(keys)
@@ -53,7 +50,7 @@ def linear_to_lora_layers(
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
         if model.model_type == "mixtral":
-            keys.add(["block_sparse_moe.gate"])
+            keys.add("block_sparse_moe.gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
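
For context on the second hunk: `set.add` takes a single hashable element, so passing a list (`keys.add(["block_sparse_moe.gate"])`) raises `TypeError: unhashable type: 'list'` instead of adding the Mixtral MoE gate to the LoRA target modules. A small standalone illustration of the intended key handling follows; the `config` values are hypothetical.

# Default modules wrapped with LoRA for most decoder-only models.
keys = {"self_attn.q_proj", "self_attn.v_proj"}

# Buggy version: keys.add(["block_sparse_moe.gate"]) would raise
# "TypeError: unhashable type: 'list'" because set elements must be
# hashable. The fix adds the module name itself:
keys.add("block_sparse_moe.gate")

# The first hunk's other path: an explicit "keys" entry in the LoRA
# config (hypothetical values here) overrides the per-model defaults.
config = {"keys": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]}
if config.get("keys") is not None:
    keys = set(config["keys"])

print(sorted(keys))
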