Add a multi optimizer (#1916)

Angelos Katharopoulos
2025-03-04 13:16:35 -08:00
committed by GitHub
parent a0737273d3
commit 9680f72cca
5 changed files with 168 additions and 1 deletions


@@ -39,6 +39,7 @@ def tree_equal(fn, *args):
 optimizers_dict = get_all_optimizers()
+del optimizers_dict["MultiOptimizer"]
 
 class TestOptimizers(mlx_tests.MLXTestCase):
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
         grads = model.trainable_parameters()
         optimizer.update(model, grads)
 
+    def test_multi_optimizer(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.MultiOptimizer(
+            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
+            [lambda name, weight: weight.ndim > 1],
+        )
+        optimizer.init(model.trainable_parameters())
+
+        self.assertEqual(len(optimizer.state["states"]), 2)
+
+        adam_states = tree_flatten(optimizer.state["states"][0])
+        sgd_states = tree_flatten(optimizer.state["states"][1])
+        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
+        self.assertFalse(any("bias" in k for k, v in adam_states))
+        self.assertFalse(any("weight" in k for k, v in sgd_states))
+
 if __name__ == "__main__":
     unittest.main()
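
For reference, a minimal usage sketch of the MultiOptimizer exercised by the test above. It is not part of the commit and assumes the behavior the assertions imply: parameters matched by the i-th filter go to the i-th optimizer, and anything left over falls through to the last one.

# Minimal usage sketch (not part of this commit): route weights (ndim > 1)
# to Adam and everything else (biases) to SGD, then run one update step.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

optimizer = opt.MultiOptimizer(
    [opt.Adam(learning_rate=1e-3), opt.SGD(learning_rate=0.1)],
    [lambda name, weight: weight.ndim > 1],  # unmatched parameters go to SGD
)


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


x = mx.random.normal((16, 4))
y = mx.random.normal((16, 1))

loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)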