mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add a multi optimizer (#1916)
This commit is contained in:
committed by
GitHub
parent
a0737273d3
commit
9680f72cca
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
from collections import defaultdict
|
||||
from itertools import zip_longest
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
|
||||
return tree if accumulator is None else fn(accumulator, tree)
|
||||
|
||||
return accumulator
|
||||
|
||||
|
||||
def tree_merge(tree_a, tree_b, merge_fn=None):
|
||||
"""Merge two Python trees in one containing the values of both. It can be
|
||||
thought of as a deep dict.update method.
|
||||
|
||||
Args:
|
||||
tree_a (Any): The first Python tree.
|
||||
tree_b (Any): The second Python tree.
|
||||
merge_fn (callable, optional): A function to merge leaves.
|
||||
|
||||
Returns:
|
||||
The Python tree containing the values of both ``tree_a`` and
|
||||
``tree_b``.
|
||||
"""
|
||||
if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
|
||||
tree_a = None
|
||||
if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
|
||||
tree_b = None
|
||||
if tree_a is None and tree_b is not None:
|
||||
return tree_b
|
||||
if tree_a is not None and tree_b is None:
|
||||
return tree_a
|
||||
|
||||
if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
|
||||
TreeType = type(tree_a)
|
||||
return TreeType(
|
||||
tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
|
||||
)
|
||||
elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
|
||||
return {
|
||||
k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
|
||||
for k in set(tree_a.keys()) | set(tree_b.keys())
|
||||
}
|
||||
else:
|
||||
if merge_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
"Trees contain elements at the same locations but no merge "
|
||||
"function was provided"
|
||||
)
|
||||
)
|
||||
return merge_fn(tree_a, tree_b)
|
||||
|
||||
Reference in New Issue
Block a user