Add broadcast_shapes in python API

This commit is contained in:
Hyunsung Lee 2025-04-19 15:37:19 +09:00
parent 5f04c0f818
commit 5f62e209f4
2 changed files with 105 additions and 0 deletions

View File

@ -289,3 +289,64 @@ def tree_merge(tree_a, tree_b, merge_fn=None):
)
)
return merge_fn(tree_a, tree_b)
def broadcast_shapes(*shapes):
    """Compute the shape that results from broadcasting the given shapes.

    Uses the same broadcasting rules as NumPy: shapes are aligned on their
    trailing axes, and two aligned axis sizes are compatible when they are
    equal or when one of them is one. Missing leading axes are treated as
    having size one.

    Args:
        *shapes: The shapes to be broadcast against each other.
            Each shape should be a tuple or list of integers.

    Returns:
        A tuple of integers representing the broadcasted shape. The result
        is always a tuple, including when a single list is passed.

    Raises:
        ValueError: If no shapes are given, or if the shapes cannot be
            broadcast according to broadcasting rules.

    Examples:
        >>> broadcast_shapes((1, 2, 3), (3,))
        (1, 2, 3)
        >>> broadcast_shapes((1, 2, 3), (4, 1, 3))
        (4, 2, 3)
        >>> broadcast_shapes((5, 1, 3), (1, 4, 3))
        (5, 4, 3)
    """
    if not shapes:
        raise ValueError("No shapes provided")

    # Normalize up front so that a single list argument is returned as a
    # tuple, matching the documented return type of the multi-shape path.
    result = tuple(shapes[0])
    for shape in shapes[1:]:
        shape = tuple(shape)
        ndim = max(len(result), len(shape))
        # Left-pad the shorter shape with ones so both align on trailing axes.
        lhs = (1,) * (ndim - len(result)) + result
        rhs = (1,) * (ndim - len(shape)) + shape

        merged = []
        for a, b in zip(lhs, rhs):
            if a == b or b == 1:
                merged.append(a)
            elif a == 1:
                # A size-one axis stretches to the other size (including 0).
                merged.append(b)
            else:
                raise ValueError(
                    f"Shapes {result} and {shape} cannot be broadcast together"
                )
        result = tuple(merged)
    return result

View File

@ -0,0 +1,44 @@
# Copyright © 2023 Apple Inc.
import mlx.utils
import mlx_tests
class TestBroadcast(mlx_tests.MLXTestCase):
    def test_broadcast_shapes(self):
        """Check broadcast_shapes against known results and error cases."""
        broadcast = mlx.utils.broadcast_shapes

        # Each entry is (input shapes, expected broadcast result).
        expected_results = [
            # Basic broadcasting
            (((1, 2, 3), (3,)), (1, 2, 3)),
            (((4, 1, 6), (5, 6)), (4, 5, 6)),
            (((5, 1, 4), (1, 3, 4)), (5, 3, 4)),
            # Multiple arguments
            (((1, 1), (1, 8), (7, 1)), (7, 8)),
            (((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5)),
            # Same shapes
            (((3, 4, 5), (3, 4, 5)), (3, 4, 5)),
            # Single argument
            (((2, 3),), (2, 3)),
            # Empty shapes
            (((), ()), ()),
            (((), (1,)), (1,)),
            (((1,), ()), (1,)),
            # Broadcasting with zeroes
            (((0,), (0,)), (0,)),
            (((1, 0, 5), (3, 1, 5)), (3, 0, 5)),
            (((5, 0), (0, 5, 0)), (0, 5, 0)),
        ]
        for inputs, expected in expected_results:
            self.assertEqual(broadcast(*inputs), expected)

        # Error cases: incompatible axes, and calling with no shapes at all.
        rejected_inputs = [
            ((3, 4), (4, 3)),
            ((2, 3, 4), (2, 5, 4)),
            (),
        ]
        for inputs in rejected_inputs:
            with self.assertRaises(ValueError):
                broadcast(*inputs)