Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-10-23 22:18:06 +08:00
chore(mlx-lm): fix print_trainable_parameters for quant models (#581)
* chore(mlx-lm): fix print_trainable_parameters for quant models
* chore: clean up
* refactor: use layer type to check quant bits
* chore: address comment
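The arithmetic behind the fix: nn.QuantizedLinear packs its quantized weights into 32-bit integers, so a layer quantized to `bits` bits stores `32 // bits` logical parameters per element of `weight`. Counting `weight.size` directly therefore undercounts the model by exactly that factor. A minimal sketch of the packing, assuming a toy nn.Linear layer (the sizes are illustrative, not from the commit):

import mlx.nn as nn

# A toy full-precision layer: 256 * 256 = 65536 logical parameters.
layer = nn.Linear(256, 256, bias=False)
print(layer.weight.size)  # 65536

# Quantize to 4 bits: 32 // 4 = 8 values are packed into each uint32,
# so the stored weight array is 8x smaller than the logical one.
qlayer = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4)
print(qlayer.weight.size)                        # 8192 stored elements
print(qlayer.weight.size * (32 // qlayer.bits))  # 65536, the logical count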
@@ -7,6 +7,7 @@ import re
 import types
 from pathlib import Path
 
+import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
 import yaml
@@ -143,7 +144,15 @@ def build_parser():
 
 
 def print_trainable_parameters(model):
-    total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
+    def nparams(m):
+        if isinstance(m, nn.QuantizedLinear):
+            return m.weight.size * (32 // m.bits)
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
     trainable_p = (
         sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
     )
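For context, a hedged sketch of how the patched counting walks the module tree: Module.leaf_modules() returns the submodules that contain no further submodules, and tree_flatten with an is_leaf predicate stops descending once it reaches an nn.Module, yielding (name, module) pairs that nparams can then dispatch on. The Toy model below is hypothetical, purely to show the traversal:

import mlx.nn as nn
from mlx.utils import tree_flatten

# Hypothetical two-layer model, only to illustrate leaf_modules().
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.QuantizedLinear(64, 8, bits=4)

model = Toy()
leaves = tree_flatten(
    model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
for name, module in leaves:
    print(name, type(module).__name__)
# fc1 Linear
# fc2 QuantizedLinear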