Merge branch 'main' into adding-support-for-mamba2

Gökdeniz Gülmez
2024-12-10 14:32:44 +01:00
committed by GitHub
31 changed files with 1579 additions and 414 deletions


@@ -98,6 +98,7 @@ def linear_to_lora_layers(
"cohere",
"minicpm",
"deepseek",
"olmo2",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]:
@@ -150,6 +151,8 @@ def linear_to_lora_layers(
"mixer.out_proj",
]
)
elif model.model_type == "exaone":
keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
else:
raise ValueError(f"Lora does not support {model.model_type}")
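
Note: the per-architecture key sets in this function (including the new exaone entry above) name the attention projections that linear_to_lora_layers wraps with LoRA adapters; a module is selected when its path ends in one of the keys. A minimal, self-contained sketch of that suffix-matching idea — the helper name and example paths are made up for illustration, not taken from the library:

def select_lora_targets(module_paths, keys):
    # Keep the paths whose trailing components match one of the keys.
    return [
        path
        for path in module_paths
        if any(path == k or path.endswith("." + k) for k in keys)
    ]

# Hypothetical exaone-style module paths:
paths = [
    "model.layers.0.attn.attention.q_proj",
    "model.layers.0.attn.attention.k_proj",
    "model.layers.0.attn.attention.v_proj",
    "model.layers.0.mlp.c_fc_0",
]
print(select_lora_targets(paths, {"attn.attention.q_proj", "attn.attention.v_proj"}))
# -> only the q_proj and v_proj paths are returned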
@@ -256,12 +259,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model
-def print_trainable_parameters(model):
-    def nparams(m):
-        if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
-            return m.weight.size * (32 // m.bits)
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
+def nparams(module):
+    if hasattr(module, "bits"):
+        n = 0 if not hasattr(module, "bias") else module.bias.size
+        return n + module.weight.size * 32 // module.bits
+    return sum(v.size for _, v in tree_flatten(module.parameters()))
+
+
+def print_trainable_parameters(model):
     leaf_modules = tree_flatten(
         model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
     )
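
Note: for quantized modules (anything exposing a bits attribute), the stored weight packs 32 // bits low-bit values into each 32-bit element, so the rewritten nparams recovers the original element count as weight.size * 32 // bits and, unlike the removed version, also counts the quantized layer's bias. A quick arithmetic check with illustrative numbers — the layer shape is made up, not from the commit:

def quantized_nparams(stored_weight_size, bits, bias_size=0):
    # Mirrors the quantized branch above: unpack the stored elements,
    # then add the bias parameters, if any.
    return bias_size + stored_weight_size * 32 // bits

# A hypothetical 4096 x 4096 linear layer quantized to 4 bits stores
# 4096 * 4096 * 4 // 32 packed 32-bit elements.
stored = 4096 * 4096 * 4 // 32
assert quantized_nparams(stored, bits=4, bias_size=4096) == 4096 * 4096 + 4096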