From 9680f72ccaae39fb203dbbeb3d29baf7f4cd090e Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 4 Mar 2025 13:16:35 -0800
Subject: [PATCH] Add a multi optimizer (#1916)

---
 python/mlx/optimizers/optimizers.py | 75 ++++++++++++++++++++++++++++-
 python/mlx/utils.py                 | 44 +++++++++++++++++
 python/tests/test_optimizers.py     | 25 ++++++++++
 python/tests/test_tree.py           | 24 +++++++++
 python/tests/test_vmap.py           |  1 +
 5 files changed, 168 insertions(+), 1 deletion(-)

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 3d40dd0d1..36068403d 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.nn import Module
-from mlx.utils import tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
 
 
 class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
         self.state[name] = param
 
 
+class MultiOptimizer(Optimizer):
+    """Wraps a list of optimizers with corresponding weight predicates/filters
+    to make it easy to use different optimizers for different weights.
+
+    Each predicate takes the full "path" of the weight and the weight itself
+    and returns ``True`` if the weight should be handled by this optimizer.
+    The last optimizer in the list is the fallback optimizer and no predicate
+    should be given for it.
+
+    Args:
+        optimizers (list[Optimizer]): A list of optimizers to delegate to.
+        filters (list[Callable[[str, array], bool]]): A list of predicates;
+            there should be one fewer predicate than optimizers.
+    """
+
+    def __init__(self, optimizers, filters: list = []):
+        super().__init__()
+        self._state = {}
+
+        if len(filters) != len(optimizers) - 1:
+            raise ValueError(
+                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+            )
+
+        self.optimizers = optimizers
+        self.filters = filters + [lambda *args, **kwargs: True]
+
+    def _split_dictionary(self, gradients: dict):
+        if len(self.optimizers) == 1:
+            return [gradients]
+
+        parts = [[] for _ in range(len(self.optimizers))]
+        flat_gradients = tree_flatten(gradients)
+        for k, g in flat_gradients:
+            for i, fn in enumerate(self.filters):
+                if fn(k, g):
+                    parts[i].append((k, g))
+                    break
+
+        return [tree_unflatten(p) for p in parts]
+
+    def init(self, parameters: dict):
+        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
+            o.init(p)
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
+        tree = {}
+        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
+            tree = tree_merge(tree, o.apply_gradients(g, parameters))
+        return tree
+
+    @property
+    def state(self):
+        return {"states": [o.state for o in self.optimizers]}
+
+    @state.setter
+    def state(self, state: dict):
+        if "states" not in state or len(state["states"]) != len(self.optimizers):
+            raise ValueError("Invalid state provided")
+
+        for o, s in zip(self.optimizers, state["states"]):
+            o.state = s
+
+    @property
+    def learning_rate(self):
+        return self.optimizers[0].learning_rate
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
+        for o in self.optimizers:
+            o.learning_rate = learning_rate
+
+
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
 
diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 6754232a6..2a3c1e660 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 
 from collections import defaultdict
+from itertools import zip_longest
 
 from typing import Any, Callable, List, Optional, Tuple
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
         return tree if accumulator is None else fn(accumulator, tree)
 
     return accumulator
+
+
+def tree_merge(tree_a, tree_b, merge_fn=None):
+    """Merge two Python trees into one containing the values of both. It can
+    be thought of as a deep ``dict.update`` method.
+
+    Args:
+        tree_a (Any): The first Python tree.
+        tree_b (Any): The second Python tree.
+        merge_fn (callable, optional): A function to merge leaves.
+
+    Returns:
+        The Python tree containing the values of both ``tree_a`` and
+        ``tree_b``.
+    """
+    if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
+        tree_a = None
+    if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
+        tree_b = None
+    if tree_a is None and tree_b is not None:
+        return tree_b
+    if tree_a is not None and tree_b is None:
+        return tree_a
+
+    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
+        TreeType = type(tree_a)
+        return TreeType(
+            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
+        )
+    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
+        return {
+            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
+            for k in set(tree_a.keys()) | set(tree_b.keys())
+        }
+    else:
+        if merge_fn is None:
+            raise ValueError(
+                (
+                    "Trees contain elements at the same locations but no merge "
+                    "function was provided"
+                )
+            )
+        return merge_fn(tree_a, tree_b)
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index cf3a2b4fa..ebfe97d80 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -39,6 +39,7 @@ def tree_equal(fn, *args):
 
 
 optimizers_dict = get_all_optimizers()
+del optimizers_dict["MultiOptimizer"]
 
 
 class TestOptimizers(mlx_tests.MLXTestCase):
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
         grads = model.trainable_parameters()
         optimizer.update(model, grads)
 
+    def test_multi_optimizer(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.MultiOptimizer(
+            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
+            [lambda name, weight: weight.ndim > 1],
+        )
+        optimizer.init(model.trainable_parameters())
+
+        self.assertEqual(len(optimizer.state["states"]), 2)
+
+        adam_states = tree_flatten(optimizer.state["states"][0])
+        sgd_states = tree_flatten(optimizer.state["states"][1])
+        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
+        self.assertFalse(any("bias" in k for k, v in adam_states))
+        self.assertFalse(any("weight" in k for k, v in sgd_states))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py
index cab137b78..63018fdae 100644
--- a/python/tests/test_tree.py
+++ b/python/tests/test_tree.py
@@ -3,6 +3,7 @@
 import unittest
 
 import mlx.core as mx
+import mlx.nn as nn
 import mlx.utils
 import mlx_tests
 
@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
         self.assertEqual(list(zip(*flat_tree))[1], vals)
         self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
 
+    def test_merge(self):
+        t1 = {"a": 0}
+        t2 = {"b": 1}
+        t = mlx.utils.tree_merge(t1, t2)
+        self.assertEqual({"a": 0, "b": 1}, t)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t1, t1)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t, t1)
+
+        mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod = nn.Sequential(mod1, mod2)
+
+        params1 = {"layers": [mod1.parameters()]}
+        params2 = {"layers": [None, mod2.parameters()]}
+        params = mlx.utils.tree_merge(params1, params2)
+        for (k1, v1), (k2, v2) in zip(
+            mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
+        ):
+            self.assertEqual(k1, k2)
+            self.assertTrue(mx.array_equal(v1, v2))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 2d38bc457..81b74d98c 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         )
 
     def test_vmap_inverse(self):
+        mx.random.seed(42)
         a = mx.random.uniform(shape=(3, 4, 4))
 
         cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
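
For context, a minimal usage sketch of the MultiOptimizer added by this patch, mirroring the pattern in test_multi_optimizer: parameters with more than one dimension are routed to Adam and everything else falls through to SGD. The model, shapes, and hyperparameters below are illustrative assumptions, not part of the patch.

# Usage sketch only (not part of the patch); model and hyperparameters are invented.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

# Matrices (ndim > 1) are routed to Adam; the trailing SGD is the fallback
# optimizer and therefore takes no predicate.
optimizer = opt.MultiOptimizer(
    [opt.Adam(learning_rate=1e-3), opt.SGD(learning_rate=1e-1)],
    [lambda path, weight: weight.ndim > 1],
)
optimizer.init(model.trainable_parameters())


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


x = mx.random.normal((32, 8))
y = mx.random.normal((32, 1))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

# Each wrapped optimizer updates only the parameters its filter selected.
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)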
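
The new tree_merge utility can also be exercised on its own; the short sketch below shows its behavior on non-overlapping and overlapping trees, with and without merge_fn. The sample trees and values are made up for illustration.

# Illustrative examples for tree_merge; the sample trees are invented.
from mlx.utils import tree_merge

# Non-overlapping entries merge like a deep dict.update, filling list gaps.
a = {"layers": [{"weight": 1}]}
b = {"layers": [None, {"weight": 2}]}
print(tree_merge(a, b))  # {'layers': [{'weight': 1}, {'weight': 2}]}

# Leaves at the same location raise unless a merge_fn is provided.
try:
    tree_merge({"w": 1}, {"w": 2})
except ValueError as e:
    print(e)

# With a merge_fn, colliding leaves are combined explicitly.
print(tree_merge({"w": 1}, {"w": 2}, merge_fn=lambda x, y: x + y))  # {'w': 3}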