Add a multi optimizer (#1916)

Angelos Katharopoulos
2025-03-04 13:16:35 -08:00
committed by GitHub
parent a0737273d3
commit 9680f72cca
5 changed files with 168 additions and 1 deletion


@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Tuple, Union
 import mlx.core as mx
 from mlx.nn import Module
-from mlx.utils import tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
 
 
 class Optimizer:
@@ -154,6 +154,79 @@ class Optimizer:
         self.state[name] = param
+
+class MultiOptimizer(Optimizer):
+    """Wraps a list of optimizers with corresponding weight predicates/filters
+    to make it easy to use different optimizers for different weights.
+
+    The predicates take the full "path" of the weight and the weight itself and
+    return True if it should be considered for this optimizer. The last
+    optimizer in the list is a fallback optimizer and no predicate should be
+    given for it.
+
+    Args:
+        optimizers (list[Optimizer]): A list of optimizers to delegate to.
+        filters (list[Callable[[str, array], bool]]): A list of predicates;
+            there should be exactly one fewer predicate than optimizers.
+    """
+
+    def __init__(self, optimizers, filters: list = []):
+        super().__init__()
+
+        self._state = {}
+        if len(filters) != len(optimizers) - 1:
+            raise ValueError(
+                f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+            )
+
+        self.optimizers = optimizers
+        self.filters = filters + [lambda *args, **kwargs: True]
+
+    def _split_dictionary(self, gradients: dict):
+        if len(self.optimizers) == 1:
+            return [gradients]
+
+        parts = [[] for _ in range(len(self.optimizers))]
+        flat_gradients = tree_flatten(gradients)
+        for k, g in flat_gradients:
+            for i, fn in enumerate(self.filters):
+                if fn(k, g):
+                    parts[i].append((k, g))
+                    break
+
+        return [tree_unflatten(p) for p in parts]
+
+    def init(self, parameters: dict):
+        for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
+            o.init(p)
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
+        tree = {}
+        for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
+            tree = tree_merge(tree, o.apply_gradients(g, parameters))
+        return tree
+
+    @property
+    def state(self):
+        return {"states": [o.state for o in self.optimizers]}
+
+    @state.setter
+    def state(self, state: dict):
+        if "states" not in state or len(state["states"]) != len(self.optimizers):
+            raise ValueError("Invalid state provided")
+
+        for o, s in zip(self.optimizers, state["states"]):
+            o.state = s
+
+    @property
+    def learning_rate(self):
+        return self.optimizers[0].learning_rate
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
+        for o in self.optimizers:
+            o.learning_rate = learning_rate
+
+
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.

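For illustration (not part of the diff), a minimal usage sketch of MultiOptimizer: route every parameter whose path contains "bias" to Adam and let SGD act as the fallback; the model, data, and hyperparameters are placeholders.

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import SGD, Adam, MultiOptimizer

model = nn.Linear(4, 4)

# Adam handles parameters matched by the predicate; SGD is the
# fallback optimizer, so no predicate is given for it.
optimizer = MultiOptimizer(
    [Adam(learning_rate=1e-3), SGD(learning_rate=1e-1)],
    [lambda path, weight: "bias" in path],
)

def loss_fn(model, x):
    return model(x).sum()

loss, grads = nn.value_and_grad(model, loss_fn)(model, mx.ones((2, 4)))
optimizer.update(model, grads)  # each sub-optimizer updates its slice
mx.eval(model.parameters(), optimizer.state)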

@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 
 from collections import defaultdict
+from itertools import zip_longest
 from typing import Any, Callable, List, Optional, Tuple
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
         return tree if accumulator is None else fn(accumulator, tree)
 
     return accumulator
+
+
+def tree_merge(tree_a, tree_b, merge_fn=None):
+    """Merge two Python trees into one containing the values of both. It can
+    be thought of as a deep ``dict.update`` method.
+
+    Args:
+        tree_a (Any): The first Python tree.
+        tree_b (Any): The second Python tree.
+        merge_fn (callable, optional): A function to merge leaves.
+
+    Returns:
+        The Python tree containing the values of both ``tree_a`` and
+        ``tree_b``.
+    """
+    if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
+        tree_a = None
+    if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
+        tree_b = None
+    if tree_a is None and tree_b is not None:
+        return tree_b
+    if tree_a is not None and tree_b is None:
+        return tree_a
+
+    if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
+        TreeType = type(tree_a)
+        return TreeType(
+            tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
+        )
+    elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
+        return {
+            k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
+            for k in set(tree_a.keys()) | set(tree_b.keys())
+        }
+    else:
+        if merge_fn is None:
+            raise ValueError(
+                "Trees contain elements at the same locations but no merge "
+                "function was provided"
+            )
+        return merge_fn(tree_a, tree_b)
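
For illustration (not part of the diff), a minimal sketch of how tree_merge behaves: disjoint leaves are combined like a deep dict.update, while leaves at the same location require an explicit merge_fn.

from mlx.utils import tree_merge

# Disjoint leaves merge structurally, filling in the missing entries.
a = {"layers": [{"w": 1}, {}]}
b = {"layers": [{}, {"w": 2}]}
print(tree_merge(a, b))  # {'layers': [{'w': 1}, {'w': 2}]}

# Colliding leaves raise a ValueError unless a merge_fn is provided.
print(tree_merge({"w": 1}, {"w": 2}, merge_fn=lambda x, y: x + y))  # {'w': 3}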