Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Add a multi optimizer (#1916)

parent a0737273d3
commit 9680f72cca
@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
 from mlx.nn import Module
-from mlx.utils import tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten


 class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
         self.state[name] = param


+class MultiOptimizer(Optimizer):
+    """Wraps a list of optimizers with corresponding weight predicates/filters
+    to make it easy to use different optimizers for different weights.
+
+    The predicates take the full "path" of the weight and the weight itself and
+    return True if it should be considered for this optimizer. The last
+    optimizer in the list is a fallback optimizer and no predicate should be
+    given for it.
+
+    Args:
+        optimizers (list[Optimizer]): A list of optimizers to delegate to.
+        filters (list[Callable[[str, array], bool]]): A list of predicates that
+            should be one less than the provided optimizers.
+    """
+
+    def __init__(self, optimizers, filters: list = []):
+        super().__init__()
+        self._state = {}
+
+        if len(filters) != len(optimizers) - 1:
+            raise ValueError(
+                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+            )
+
+        self.optimizers = optimizers
+        self.filters = filters + [lambda *args, **kwargs: True]
+
+    def _split_dictionary(self, gradients: dict):
+        if len(self.optimizers) == 1:
+            return [gradients]
+
+        parts = [[] for _ in range(len(self.optimizers))]
+        flat_gradients = tree_flatten(gradients)
+        for k, g in flat_gradients:
+            for i, fn in enumerate(self.filters):
+                if fn(k, g):
+                    parts[i].append((k, g))
+                    break
+
+        return [tree_unflatten(p) for p in parts]
+
+    def init(self, parameters: dict):
+        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
+            o.init(p)
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
+        tree = {}
+        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
+            tree = tree_merge(tree, o.apply_gradients(g, parameters))
+        return tree
+
+    @property
+    def state(self):
+        return {"states": [o.state for o in self.optimizers]}
+
+    @state.setter
+    def state(self, state: dict):
+        if "states" not in state or len(state["states"]) != len(self.optimizers):
+            raise ValueError("Invalid state provided")
+
+        for o, s in zip(self.optimizers, state["states"]):
+            o.state = s
+
+    @property
+    def learning_rate(self):
+        return self.optimizers[0].learning_rate
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
+        for o in self.optimizers:
+            o.learning_rate = learning_rate
+
+
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.

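For orientation, a minimal usage sketch of the class added above; the model, layer sizes, and learning rates are illustrative and not taken from the diff. Parameters whose path/weight pair matches a predicate go to the corresponding optimizer, and everything else falls through to the trailing fallback optimizer, which takes no predicate:

    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

    # 2-D weight matrices are routed to Adam; biases (ndim == 1) fall through
    # to the fallback SGD instance, which has no predicate of its own.
    optimizer = optim.MultiOptimizer(
        [optim.Adam(learning_rate=1e-3), optim.SGD(learning_rate=1e-1)],
        [lambda path, param: param.ndim > 1],
    )
    optimizer.init(model.trainable_parameters())

    # In a training step the gradient tree is split by the same predicates and
    # each sub-optimizer updates only its share, e.g. (loss_fn, x, y are
    # placeholders):
    #     loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    #     optimizer.update(model, grads)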
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.

 from collections import defaultdict
+from itertools import zip_longest
 from typing import Any, Callable, List, Optional, Tuple

@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
         return tree if accumulator is None else fn(accumulator, tree)

     return accumulator
+
+
+def tree_merge(tree_a, tree_b, merge_fn=None):
+    """Merge two Python trees into one containing the values of both. It can be
+    thought of as a deep dict.update method.
+
+    Args:
+        tree_a (Any): The first Python tree.
+        tree_b (Any): The second Python tree.
+        merge_fn (callable, optional): A function to merge leaves.
+
+    Returns:
+        The Python tree containing the values of both ``tree_a`` and
+        ``tree_b``.
+    """
+    if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
+        tree_a = None
+    if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
+        tree_b = None
+    if tree_a is None and tree_b is not None:
+        return tree_b
+    if tree_a is not None and tree_b is None:
+        return tree_a
+
+    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
+        TreeType = type(tree_a)
+        return TreeType(
+            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
+        )
+    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
+        return {
+            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
+            for k in set(tree_a.keys()) | set(tree_b.keys())
+        }
+    else:
+        if merge_fn is None:
+            raise ValueError(
+                (
+                    "Trees contain elements at the same locations but no merge "
+                    "function was provided"
+                )
+            )
+        return merge_fn(tree_a, tree_b)
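A small sketch of how the new helper behaves, assuming a build that includes this commit; the dictionaries are made up for illustration. Disjoint branches are combined recursively, lists are merged position-wise (with zip_longest padding the shorter side), and leaves present in both trees require an explicit merge_fn:

    from mlx.utils import tree_merge

    a = {"layers": [{"weight": 1}, None]}
    b = {"layers": [None, {"weight": 2}]}
    print(tree_merge(a, b))
    # {'layers': [{'weight': 1}, {'weight': 2}]}

    # Overlapping leaves raise ValueError unless merge_fn says how to combine.
    print(tree_merge({"x": 1}, {"x": 2}, merge_fn=lambda u, v: u + v))
    # {'x': 3}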
@@ -39,6 +39,7 @@ def tree_equal(fn, *args):


 optimizers_dict = get_all_optimizers()
+del optimizers_dict["MultiOptimizer"]


 class TestOptimizers(mlx_tests.MLXTestCase):
@@ -500,6 +501,30 @@ class TestSchedulers(unittest.TestCase):
         grads = model.trainable_parameters()
         optimizer.update(model, grads)

+    def test_multi_optimizer(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.MultiOptimizer(
+            [opt.Adam(learning_rate=0.001), opt.SGD(learning_rate=0.1)],
+            [lambda name, weight: weight.ndim > 1],
+        )
+        optimizer.init(model.trainable_parameters())
+
+        self.assertEqual(len(optimizer.state["states"]), 2)
+
+        adam_states = tree_flatten(optimizer.state["states"][0])
+        sgd_states = tree_flatten(optimizer.state["states"][1])
+        self.assertEqual((len(sgd_states) - 2) * 2, len(adam_states) - 2)
+        self.assertFalse(any("bias" in k for k, v in adam_states))
+        self.assertFalse(any("weight" in k for k, v in sgd_states))
+

 if __name__ == "__main__":
     unittest.main()
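A note on the state-count assertion in the test above: the arithmetic appears to rely on Adam keeping two moment arrays per parameter and SGD keeping one, with each optimizer's flattened state also carrying two scalar entries (step and learning rate); those layouts are assumptions here, not something stated in the diff. A back-of-the-envelope check under them, using the three weight matrices and three biases of the test model:

    # Assumed layout: two arrays per parameter for Adam, one for SGD, plus two
    # scalar entries (step, learning_rate) per optimizer state.
    n_weights, n_biases = 3, 3            # l1, l2 and vals[0] each have both
    adam_entries = 2 + 2 * n_weights      # 8 flattened entries
    sgd_entries = 2 + 1 * n_biases        # 5 flattened entries
    assert (sgd_entries - 2) * 2 == adam_entries - 2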
@@ -3,6 +3,7 @@
 import unittest

 import mlx.core as mx
+import mlx.nn as nn
 import mlx.utils
 import mlx_tests

@@ -22,6 +23,29 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
         self.assertEqual(list(zip(*flat_tree))[1], vals)
         self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)

+    def test_merge(self):
+        t1 = {"a": 0}
+        t2 = {"b": 1}
+        t = mlx.utils.tree_merge(t1, t2)
+        self.assertEqual({"a": 0, "b": 1}, t)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t1, t1)
+        with self.assertRaises(ValueError):
+            mlx.utils.tree_merge(t, t1)
+
+        mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+        mod = nn.Sequential(mod1, mod2)
+
+        params1 = {"layers": [mod1.parameters()]}
+        params2 = {"layers": [None, mod2.parameters()]}
+        params = mlx.utils.tree_merge(params1, params2)
+        for (k1, v1), (k2, v2) in zip(
+            mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters())
+        ):
+            self.assertEqual(k1, k2)
+            self.assertTrue(mx.array_equal(v1, v2))
+

 if __name__ == "__main__":
     unittest.main()
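One detail worth noting in the list-merging case above: params1 has a single-element "layers" list while params2 has two entries, so the merge depends on zip_longest padding the shorter list with None, which tree_merge then resolves to the value from the other side. A one-line illustration of that padding (values are arbitrary):

    from itertools import zip_longest

    # The shorter list is padded with None before the element-wise merge.
    print(list(zip_longest([1], [None, 2])))  # [(1, None), (None, 2)]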
@@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         )

     def test_vmap_inverse(self):
+        mx.random.seed(42)
         a = mx.random.uniform(shape=(3, 4, 4))

         cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)