fix to use c++ api

This commit is contained in:
Hyunsung Lee
2025-04-20 12:55:58 +09:00
parent 876c1986e4
commit a7a96b0ad6
3 changed files with 85 additions and 78 deletions

View File

@@ -289,64 +289,3 @@ def tree_merge(tree_a, tree_b, merge_fn=None):
)
)
return merge_fn(tree_a, tree_b)
def broadcast_shapes(*shapes):
"""Broadcast shapes to the same size.
Uses the same broadcasting rules as NumPy. The size of the trailing axes
for both arrays in an operation must either be the same size or one of
them must be one.
Args:
*shapes: The shapes to be broadcast against each other.
Each shape should be a tuple or list of integers.
Returns:
A tuple of integers representing the broadcasted shape.
Raises:
ValueError: If the shapes cannot be broadcast according to broadcasting rules.
Examples:
>>> broadcast_shapes((1, 2, 3), (3,))
(1, 2, 3)
>>> broadcast_shapes((1, 2, 3), (4, 1, 3))
(4, 2, 3)
>>> broadcast_shapes((5, 1, 3), (1, 4, 3))
(5, 4, 3)
"""
if len(shapes) == 0:
raise ValueError("No shapes provided")
if len(shapes) == 1:
return shapes[0]
result = shapes[0]
for shape in shapes[1:]:
ndim1 = len(result)
ndim2 = len(shape)
ndim = max(ndim1, ndim2)
diff = abs(ndim1 - ndim2)
big = result if ndim1 > ndim2 else shape
small = shape if ndim1 > ndim2 else result
out_shape = []
for i in range(ndim - 1, diff - 1, -1):
a = big[i]
b = small[i - diff]
if a == b:
out_shape.insert(0, a)
elif a == 1 or b == 1:
out_shape.insert(0, a * b)
else:
raise ValueError(
f"Shapes {result} and {shape} cannot be broadcast together"
)
for i in range(diff - 1, -1, -1):
out_shape.insert(0, big[i])
result = tuple(out_shape)
return result