diff --git a/python/mlx/utils.py b/python/mlx/utils.py index de1781da3f..b7173deb71 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -289,64 +289,3 @@ def tree_merge(tree_a, tree_b, merge_fn=None): ) ) return merge_fn(tree_a, tree_b) - - -def broadcast_shapes(*shapes): - """Broadcast shapes to the same size. - - Uses the same broadcasting rules as NumPy. The size of the trailing axes - for both arrays in an operation must either be the same size or one of - them must be one. - - Args: - *shapes: The shapes to be broadcast against each other. - Each shape should be a tuple or list of integers. - - Returns: - A tuple of integers representing the broadcasted shape. - - Raises: - ValueError: If the shapes cannot be broadcast according to broadcasting rules. - - Examples: - >>> broadcast_shapes((1, 2, 3), (3,)) - (1, 2, 3) - >>> broadcast_shapes((1, 2, 3), (4, 1, 3)) - (4, 2, 3) - >>> broadcast_shapes((5, 1, 3), (1, 4, 3)) - (5, 4, 3) - """ - if len(shapes) == 0: - raise ValueError("No shapes provided") - - if len(shapes) == 1: - return shapes[0] - - result = shapes[0] - for shape in shapes[1:]: - ndim1 = len(result) - ndim2 = len(shape) - ndim = max(ndim1, ndim2) - - diff = abs(ndim1 - ndim2) - big = result if ndim1 > ndim2 else shape - small = shape if ndim1 > ndim2 else result - - out_shape = [] - for i in range(ndim - 1, diff - 1, -1): - a = big[i] - b = small[i - diff] - - if a == b: - out_shape.insert(0, a) - elif a == 1 or b == 1: - out_shape.insert(0, a * b) - else: - raise ValueError( - f"Shapes {result} and {shape} cannot be broadcast together" - ) - - for i in range(diff - 1, -1, -1): - out_shape.insert(0, big[i]) - result = tuple(out_shape) - return result diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aac..8a70cba054 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5189,4 +5189,72 @@ void init_ops(nb::module_& m) { Returns: array: The row or col contiguous output. 
)pbdoc"); + m.def( + "broadcast_shapes", + [](const nb::args& shapes) { + if (shapes.size() == 0) { + throw std::invalid_argument( + "broadcast_shapes expects a sequence of shapes"); + } + + std::vector<mx::Shape> shape_vec; + shape_vec.reserve(shapes.size()); + + for (size_t i = 0; i < shapes.size(); ++i) { + mx::Shape shape; + + if (nb::isinstance<nb::tuple>(shapes[i])) { + nb::tuple t = nb::cast<nb::tuple>(shapes[i]); + for (size_t j = 0; j < t.size(); ++j) { + shape.push_back(nb::cast<int>(t[j])); + } + } else if (nb::isinstance<nb::list>(shapes[i])) { + nb::list l = nb::cast<nb::list>(shapes[i]); + for (size_t j = 0; j < l.size(); ++j) { + shape.push_back(nb::cast<int>(l[j])); + } + } else { + throw std::invalid_argument( + "broadcast_shapes expects a sequence of shapes"); + } + + shape_vec.push_back(shape); + } + + if (shape_vec.empty()) { + return nb::tuple(); + } + + mx::Shape result = shape_vec[0]; + for (size_t i = 1; i < shape_vec.size(); ++i) { + result = mx::broadcast_shapes(result, shape_vec[i]); + } + + auto py_list = nb::cast(result); + return nb::tuple(py_list); + }, + nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Sequence[int]"), + R"pbdoc( + Broadcast shapes. + + Returns the shape that results from broadcasting the supplied array shapes + against each other. + + Args: + *shapes (Sequence[int]): The shapes to broadcast. + + Returns: + tuple: The broadcasted shape. + + Raises: + ValueError: If the shapes cannot be broadcast. + + Example: + >>> mx.broadcast_shapes((1,), (3, 1)) + (3, 1) + >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) + (5, 6, 7) + >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) + (5, 3, 4) + )pbdoc"); } diff --git a/python/tests/test_broadcast.py b/python/tests/test_broadcast.py index 7b11dcb5f5..216a4116f2 100644 --- a/python/tests/test_broadcast.py +++ b/python/tests/test_broadcast.py @@ -1,44 +1,44 @@ # Copyright © 2025 Apple Inc.
-import mlx.utils +import mlx.core import mlx_tests class TestBroadcast(mlx_tests.MLXTestCase): def test_broadcast_shapes(self): # Basic broadcasting - self.assertEqual(mlx.utils.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) - self.assertEqual(mlx.utils.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) - self.assertEqual(mlx.utils.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + self.assertEqual(mlx.core.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mlx.core.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mlx.core.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) # Multiple arguments - self.assertEqual(mlx.utils.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual(mlx.core.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) self.assertEqual( - mlx.utils.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + mlx.core.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) ) # Same shapes - self.assertEqual(mlx.utils.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + self.assertEqual(mlx.core.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) # Single argument - self.assertEqual(mlx.utils.broadcast_shapes((2, 3)), (2, 3)) + self.assertEqual(mlx.core.broadcast_shapes((2, 3)), (2, 3)) # Empty shapes - self.assertEqual(mlx.utils.broadcast_shapes((), ()), ()) - self.assertEqual(mlx.utils.broadcast_shapes((), (1,)), (1,)) - self.assertEqual(mlx.utils.broadcast_shapes((1,), ()), (1,)) + self.assertEqual(mlx.core.broadcast_shapes((), ()), ()) + self.assertEqual(mlx.core.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mlx.core.broadcast_shapes((1,), ()), (1,)) # Broadcasting with zeroes - self.assertEqual(mlx.utils.broadcast_shapes((0,), (0,)), (0,)) - self.assertEqual(mlx.utils.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) - self.assertEqual(mlx.utils.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + self.assertEqual(mlx.core.broadcast_shapes((0,), (0,)), (0,)) + 
self.assertEqual(mlx.core.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mlx.core.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) # Error cases with self.assertRaises(ValueError): - mlx.utils.broadcast_shapes((3, 4), (4, 3)) + mlx.core.broadcast_shapes((3, 4), (4, 3)) with self.assertRaises(ValueError): - mlx.utils.broadcast_shapes((2, 3, 4), (2, 5, 4)) + mlx.core.broadcast_shapes((2, 3, 4), (2, 5, 4)) with self.assertRaises(ValueError): - mlx.utils.broadcast_shapes() + mlx.core.broadcast_shapes()