# File: python/mlx/utils.py (appended after tree_merge)


def broadcast_shapes(*shapes):
    """Compute the shape that results from broadcasting ``shapes`` together.

    The shapes are aligned on their trailing (right-most) axes. Two aligned
    axis sizes are compatible when they are equal or when one of them is 1;
    the resulting size is the one that is not 1. A shape with fewer
    dimensions is treated as if it were left-padded with 1s.

    Args:
        *shapes: The shapes to be broadcast against each other.
            Each shape should be a tuple or list of integers.

    Returns:
        A tuple of integers representing the broadcasted shape. A tuple is
        returned even when a single shape is given as a list.

    Raises:
        ValueError: If no shapes are provided, or if the shapes cannot be
            broadcast according to the rules above.

    Examples:
        >>> broadcast_shapes((1, 2, 3), (3,))
        (1, 2, 3)
        >>> broadcast_shapes((1, 2, 3), (4, 1, 3))
        (4, 2, 3)
        >>> broadcast_shapes((5, 1, 3), (1, 4, 3))
        (5, 4, 3)
    """
    if not shapes:
        raise ValueError("No shapes provided")

    # Normalize to a tuple up front so the return type does not depend on
    # how many shapes were passed or on whether the caller used lists.
    result = tuple(shapes[0])
    for shape in shapes[1:]:
        shape = tuple(shape)
        ndim = max(len(result), len(shape))

        # Left-pad the shorter shape with 1s so both can be walked in
        # lockstep; a missing leading axis behaves exactly like size 1.
        lhs = (1,) * (ndim - len(result)) + result
        rhs = (1,) * (ndim - len(shape)) + shape

        merged = []
        for a, b in zip(lhs, rhs):
            if a == b or b == 1:
                merged.append(a)
            elif a == 1:
                merged.append(b)
            else:
                raise ValueError(
                    f"Shapes {result} and {shape} cannot be broadcast together"
                )
        result = tuple(merged)
    return result


# File: python/tests/test_broadcast.py
# Copyright © 2023 Apple Inc.
import mlx.utils
import mlx_tests


class TestBroadcast(mlx_tests.MLXTestCase):
    """Checks for ``mlx.utils.broadcast_shapes``."""

    def test_broadcast_shapes(self):
        """Verify successful broadcasts first, then the rejected inputs."""
        broadcast = mlx.utils.broadcast_shapes

        # Each entry is (argument shapes, expected broadcast shape).
        valid_cases = [
            # Basic broadcasting
            (((1, 2, 3), (3,)), (1, 2, 3)),
            (((4, 1, 6), (5, 6)), (4, 5, 6)),
            (((5, 1, 4), (1, 3, 4)), (5, 3, 4)),
            # Multiple arguments
            (((1, 1), (1, 8), (7, 1)), (7, 8)),
            (((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5)),
            # Same shapes
            (((3, 4, 5), (3, 4, 5)), (3, 4, 5)),
            # Single argument
            (((2, 3),), (2, 3)),
            # Empty shapes
            (((), ()), ()),
            (((), (1,)), (1,)),
            (((1,), ()), (1,)),
            # Broadcasting with zeroes
            (((0,), (0,)), (0,)),
            (((1, 0, 5), (3, 1, 5)), (3, 0, 5)),
            (((5, 0), (0, 5, 0)), (0, 5, 0)),
        ]
        for arg_shapes, expected in valid_cases:
            self.assertEqual(broadcast(*arg_shapes), expected)

        # Error cases: mismatched axes, and calling with no shapes at all.
        invalid_cases = [
            ((3, 4), (4, 3)),
            ((2, 3, 4), (2, 5, 4)),
            (),
        ]
        for arg_shapes in invalid_cases:
            with self.assertRaises(ValueError):
                broadcast(*arg_shapes)