From 383644524150f39c96008ba6645201ec23bb44d6 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 23 Apr 2025 10:57:39 +0900 Subject: [PATCH] Add broadcast_shapes in python API (#2091) --- python/src/ops.cpp | 42 ++++++++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aa..5969c5052 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5189,4 +5189,46 @@ void init_ops(nb::module_& m) { Returns: array: The row or col contiguous output. )pbdoc"); + m.def( + "broadcast_shapes", + [](const nb::args& shapes) { + if (shapes.size() == 0) + throw std::invalid_argument( + "[broadcast_shapes] Must provide at least one shape."); + + mx::Shape result = nb::cast(shapes[0]); + for (size_t i = 1; i < shapes.size(); ++i) { + if (!nb::isinstance(shapes[i]) && + !nb::isinstance(shapes[i])) + throw std::invalid_argument( + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); + } + + return nb::tuple(nb::cast(result)); + }, + nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), + R"pbdoc( + Broadcast shapes. + + Returns the shape that results from broadcasting the supplied array shapes + against each other. + + Args: + *shapes (Sequence[int]): The shapes to broadcast. + + Returns: + tuple: The broadcasted shape. + + Raises: + ValueError: If the shapes cannot be broadcast. + + Example: + >>> mx.broadcast_shapes((1,), (3, 1)) + (3, 1) + >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) + (5, 6, 7) + >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) + (5, 3, 4) + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d0e52eab2..47fec3167 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3043,5 +3043,45 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) +class TestBroadcast(mlx_tests.MLXTestCase): + def test_broadcast_shapes(self): + # Basic broadcasting + self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + + # Multiple arguments + self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual( + mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + ) + + # Same shapes + self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + + # Single argument + self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3)) + + # Empty shapes + self.assertEqual(mx.broadcast_shapes((), ()), ()) + self.assertEqual(mx.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mx.broadcast_shapes((1,), ()), (1,)) + + # Broadcasting with zeroes + self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,)) + self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + + # Error cases + with self.assertRaises(ValueError): + mx.broadcast_shapes((3, 4), (4, 3)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes((2, 3, 4), (2, 5, 4)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes() + + if __name__ == "__main__": unittest.main()