Add a multi optimizer (#1916)

This commit is contained in:
Angelos Katharopoulos 2025-03-04 13:16:35 -08:00 committed by GitHub
parent a0737273d3
commit 9680f72cca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 1 deletions

View File

@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
class Optimizer:
@ -154,6 +154,79 @@ class Optimizer:
self.state[name] = param
class MultiOptimizer(Optimizer):
"""Wraps a list of optimizers with corresponding weight predicates/filters
to make it easy to use different optimizers for different weights.
The predicates take the full "path" of the weight and the weight itself and
return True if it should be considered for this optimizer. The last
optimizer in the list is a fallback optimizer and no predicate should be
given for it.
Args:
optimizers (list[Optimizer]): A list of optimizers to delegate to
filters (list[Callable[[str, array], bool]): A list of predicates that
should be one less than the provided optimizers.
"""
def __init__(self, optimizers, filters: list = []):
super().__init__()
self._state = {}
if len(filters) != len(optimizers) - 1:
raise ValueError(
f"Given {len(filters)} filters but {len(optimizers)-1} needed."
)
self.optimizers = optimizers
self.filters = filters + [lambda *args, **kwargs: True]
def _split_dictionary(self, gradients: dict):
if len(self.optimizers) == 1:
return [gradients]
parts = [[] for _ in range(len(self.optimizers))]
flat_gradients = tree_flatten(gradients)
for k, g in flat_gradients:
for i, fn in enumerate(self.filters):
if fn(k, g):
parts[i].append((k, g))
break
return [tree_unflatten(p) for p in parts]
def init(self, parameters: dict):
for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
o.init(p)
def apply_gradients(self, gradients: dict, parameters: dict):
tree = {}
for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
tree = tree_merge(tree, o.apply_gradients(g, parameters))
return tree
@property
def state(self):
return {"states": [o.state for o in self.optimizers]}
@state.setter
def state(self, state: dict):
if "states" not in state or len(state["states"]) != len(self.optimizers):
raise ValueError("Invalid state provided")
for o, s in zip(self.optimizers, state["states"]):
o.state = s
@property
def learning_rate(self):
return self.optimizers[0].learning_rate
@learning_rate.setter
def learning_rate(self, learning_rate: Union[float, mx.array]):
for o in self.optimizers:
o.learning_rate = learning_rate
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.

View File

@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
return tree if accumulator is None else fn(accumulator, tree)
return accumulator
def tree_merge(tree_a, tree_b, merge_fn=None):
"""Merge two Python trees in one containing the values of both. It can be
thought of as a deep dict.update method.
Args:
tree_a (Any): The first Python tree.
tree_b (Any): The second Python tree.
merge_fn (callable, optional): A function to merge leaves.
Returns:
The Python tree containing the values of both ``tree_a`` and
``tree_b``.
"""
if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
tree_a = None
if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
tree_b = None
if tree_a is None and tree_b is not None:
return tree_b
if tree_a is not None and tree_b is None:
return tree_a
if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
TreeType = type(tree_a)
return TreeType(
tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
)
elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
return {
k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
for k in set(tree_a.keys()) | set(tree_b.keys())
}
else:
if merge_fn is None:
raise ValueError(
(
"Trees contain elements at the same locations but no merge "
"function was provided"
)
)
return merge_fn(tree_a, tree_b)

View File

@ -39,6 +39,7 @@ def tree_equal(fn, *args):
optimizers_dict = get_all_optimizers()
del optimizers_dict["MultiOptimizer"]
class TestOptimizers(mlx_tests.MLXTestCase):
@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
grads = model.trainable_parameters()
optimizer.update(model, grads)
def test_multi_optimizer(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(2, 2)
self.drop = nn.Dropout(p=0.5)
self.l2 = nn.Linear(2, 2)
self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
model = Model()
optimizer = opt.MultiOptimizer(
[opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
[lambda name, weight: weight.ndim > 1],
)
optimizer.init(model.trainable_parameters())
self.assertEqual(len(optimizer.state["states"]), 2)
adam_states = tree_flatten(optimizer.state["states"][0])
sgd_states = tree_flatten(optimizer.state["states"][1])
self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
self.assertFalse(any("bias" in k for k, v in adam_states))
self.assertFalse(any("weight" in k for k, v in sgd_states))
if __name__ == "__main__":
unittest.main()

View File

@ -3,6 +3,7 @@
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import mlx_tests
@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
self.assertEqual(list(zip(*flat_tree))[1], vals)
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
def test_merge(self):
t1 = {"a": 0}
t2 = {"b": 1}
t = mlx.utils.tree_merge(t1, t2)
self.assertEqual({"a": 0, "b": 1}, t)
with self.assertRaises(ValueError):
mlx.utils.tree_merge(t1, t1)
with self.assertRaises(ValueError):
mlx.utils.tree_merge(t, t1)
mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
mod = nn.Sequential(mod1, mod2)
params1 = {"layers": [mod1.parameters()]}
params2 = {"layers": [None, mod2.parameters()]}
params = mlx.utils.tree_merge(params1, params2)
for (k1, v1), (k2, v2) in zip(
mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
):
self.assertEqual(k1, k2)
self.assertTrue(mx.array_equal(v1, v2))
if __name__ == "__main__":
unittest.main()

View File

@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
)
def test_vmap_inverse(self):
mx.random.seed(42)
a = mx.random.uniform(shape=(3, 4, 4))
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)