mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 20:58:08 +08:00
Add a multi optimizer (#1916)
This commit is contained in:

committed by
GitHub

parent
a0737273d3
commit
9680f72cca
@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn import Module
|
||||
from mlx.utils import tree_map, tree_reduce
|
||||
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
|
||||
|
||||
|
||||
class Optimizer:
|
||||
@@ -154,6 +154,79 @@ class Optimizer:
|
||||
self.state[name] = param
|
||||
|
||||
|
||||
class MultiOptimizer(Optimizer):
|
||||
"""Wraps a list of optimizers with corresponding weight predicates/filters
|
||||
to make it easy to use different optimizers for different weights.
|
||||
|
||||
The predicates take the full "path" of the weight and the weight itself and
|
||||
return True if it should be considered for this optimizer. The last
|
||||
optimizer in the list is a fallback optimizer and no predicate should be
|
||||
given for it.
|
||||
|
||||
Args:
|
||||
optimizers (list[Optimizer]): A list of optimizers to delegate to
|
||||
filters (list[Callable[[str, array], bool]): A list of predicates that
|
||||
should be one less than the provided optimizers.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizers, filters: list = []):
|
||||
super().__init__()
|
||||
self._state = {}
|
||||
|
||||
if len(filters) != len(optimizers) - 1:
|
||||
raise ValueError(
|
||||
f"Given {len(filters)} filters but {len(optimizers)-1} needed."
|
||||
)
|
||||
|
||||
self.optimizers = optimizers
|
||||
self.filters = filters + [lambda *args, **kwargs: True]
|
||||
|
||||
def _split_dictionary(self, gradients: dict):
|
||||
if len(self.optimizers) == 1:
|
||||
return [gradients]
|
||||
|
||||
parts = [[] for _ in range(len(self.optimizers))]
|
||||
flat_gradients = tree_flatten(gradients)
|
||||
for k, g in flat_gradients:
|
||||
for i, fn in enumerate(self.filters):
|
||||
if fn(k, g):
|
||||
parts[i].append((k, g))
|
||||
break
|
||||
|
||||
return [tree_unflatten(p) for p in parts]
|
||||
|
||||
def init(self, parameters: dict):
|
||||
for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
|
||||
o.init(p)
|
||||
|
||||
def apply_gradients(self, gradients: dict, parameters: dict):
|
||||
tree = {}
|
||||
for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
|
||||
tree = tree_merge(tree, o.apply_gradients(g, parameters))
|
||||
return tree
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return {"states": [o.state for o in self.optimizers]}
|
||||
|
||||
@state.setter
|
||||
def state(self, state: dict):
|
||||
if "states" not in state or len(state["states"]) != len(self.optimizers):
|
||||
raise ValueError("Invalid state provided")
|
||||
|
||||
for o, s in zip(self.optimizers, state["states"]):
|
||||
o.state = s
|
||||
|
||||
@property
|
||||
def learning_rate(self):
|
||||
return self.optimizers[0].learning_rate
|
||||
|
||||
@learning_rate.setter
|
||||
def learning_rate(self, learning_rate: Union[float, mx.array]):
|
||||
for o in self.optimizers:
|
||||
o.learning_rate = learning_rate
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
r"""The stochastic gradient descent optimizer.
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
from collections import defaultdict
|
||||
from itertools import zip_longest
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
|
||||
return tree if accumulator is None else fn(accumulator, tree)
|
||||
|
||||
return accumulator
|
||||
|
||||
|
||||
def tree_merge(tree_a, tree_b, merge_fn=None):
|
||||
"""Merge two Python trees in one containing the values of both. It can be
|
||||
thought of as a deep dict.update method.
|
||||
|
||||
Args:
|
||||
tree_a (Any): The first Python tree.
|
||||
tree_b (Any): The second Python tree.
|
||||
merge_fn (callable, optional): A function to merge leaves.
|
||||
|
||||
Returns:
|
||||
The Python tree containing the values of both ``tree_a`` and
|
||||
``tree_b``.
|
||||
"""
|
||||
if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
|
||||
tree_a = None
|
||||
if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
|
||||
tree_b = None
|
||||
if tree_a is None and tree_b is not None:
|
||||
return tree_b
|
||||
if tree_a is not None and tree_b is None:
|
||||
return tree_a
|
||||
|
||||
if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
|
||||
TreeType = type(tree_a)
|
||||
return TreeType(
|
||||
tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
|
||||
)
|
||||
elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
|
||||
return {
|
||||
k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
|
||||
for k in set(tree_a.keys()) | set(tree_b.keys())
|
||||
}
|
||||
else:
|
||||
if merge_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
"Trees contain elements at the same locations but no merge "
|
||||
"function was provided"
|
||||
)
|
||||
)
|
||||
return merge_fn(tree_a, tree_b)
|
||||
|
Reference in New Issue
Block a user