Add a multi optimizer (#1916)

This commit is contained in:
Angelos Katharopoulos
2025-03-04 13:16:35 -08:00
committed by GitHub
parent a0737273d3
commit 9680f72cca
5 changed files with 168 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
@@ -244,3 +245,46 @@ def tree_reduce(fn, tree, initializer=None, is_leaf=None):
return tree if accumulator is None else fn(accumulator, tree)
return accumulator
def tree_merge(tree_a, tree_b, merge_fn=None):
"""Merge two Python trees in one containing the values of both. It can be
thought of as a deep dict.update method.
Args:
tree_a (Any): The first Python tree.
tree_b (Any): The second Python tree.
merge_fn (callable, optional): A function to merge leaves.
Returns:
The Python tree containing the values of both ``tree_a`` and
``tree_b``.
"""
if isinstance(tree_a, (dict, list, tuple)) and len(tree_a) == 0:
tree_a = None
if isinstance(tree_b, (dict, list, tuple)) and len(tree_b) == 0:
tree_b = None
if tree_a is None and tree_b is not None:
return tree_b
if tree_a is not None and tree_b is None:
return tree_a
if isinstance(tree_a, (list, tuple)) and isinstance(tree_b, (list, tuple)):
TreeType = type(tree_a)
return TreeType(
tree_merge(a, b, merge_fn) for a, b in zip_longest(tree_a, tree_b)
)
elif isinstance(tree_a, dict) and isinstance(tree_b, dict):
return {
k: tree_merge(tree_a.get(k, None), tree_b.get(k, None), merge_fn)
for k in set(tree_a.keys()) | set(tree_b.keys())
}
else:
if merge_fn is None:
raise ValueError(
(
"Trees contain elements at the same locations but no merge "
"function was provided"
)
)
return merge_fn(tree_a, tree_b)