Add support for OpenELM (#719)

* add openELM

* update splitting logic

* update qkv logic, transformer and MLP block

* code formatting and fix args

* fix array slicing and remove unused var :)

* add to tuner

* use mx.split for slicing qkv

* merge with phi3

* remove rope scaling logic

* code formatting
Prince Canuma
2024-04-26 01:49:28 +02:00
committed by GitHub
parent 2c1c9e9024
commit c012eb173f
2 changed files with 229 additions and 1 deletions


@@ -87,7 +87,7 @@ def linear_to_lora_layers(
             keys.add("mlp.shared_expert_gate")
     elif model.model_type == "olmo":
         keys = set(["att_proj"])
-    elif model.model_type == "phi3":
+    elif model.model_type in ["phi3", "openelm"]:
         keys = set(["self_attn.qkv_proj"])
     elif model.model_type == "phi-msft":
         keys = set(["mixer.Wqkv", "moe.gate"])