diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 513efca1a..e858f7e8f 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5201,19 +5201,12 @@ void init_ops(nb::module_& m) { for (size_t i = 0; i < shapes.size(); ++i) { mx::Shape shape; - if (nb::isinstance(shapes[i])) { - nb::tuple t = nb::cast(shapes[i]); - for (size_t j = 0; j < t.size(); ++j) { - shape.push_back(nb::cast(t[j])); - } - } else if (nb::isinstance(shapes[i])) { - nb::list l = nb::cast(shapes[i]); - for (size_t j = 0; j < l.size(); ++j) { - shape.push_back(nb::cast(l[j])); - } + if (nb::isinstance(shapes[i]) || + nb::isinstance(shapes[i])) { + shape = nb::cast(shapes[i]); } else { throw std::invalid_argument( - "broadcast_shapes expects a sequence of shapes"); + "broadcast_shapes expects a sequence of shapes (tuple or list of ints)"); } shape_vec.push_back(shape); diff --git a/python/tests/test_broadcast.py b/python/tests/test_broadcast.py deleted file mode 100644 index 216a4116f..000000000 --- a/python/tests/test_broadcast.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright © 2025 Apple Inc. - -import mlx.core -import mlx_tests - - -class TestBroadcast(mlx_tests.MLXTestCase): - def test_broadcast_shapes(self): - # Basic broadcasting - self.assertEqual(mlx.core.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) - self.assertEqual(mlx.core.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) - self.assertEqual(mlx.core.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) - - # Multiple arguments - self.assertEqual(mlx.core.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) - self.assertEqual( - mlx.core.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) - ) - - # Same shapes - self.assertEqual(mlx.core.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) - - # Single argument - self.assertEqual(mlx.core.broadcast_shapes((2, 3)), (2, 3)) - - # Empty shapes - self.assertEqual(mlx.core.broadcast_shapes((), ()), ()) - self.assertEqual(mlx.core.broadcast_shapes((), (1,)), (1,)) - self.assertEqual(mlx.core.broadcast_shapes((1,), ()), (1,)) - - # Broadcasting with zeroes - self.assertEqual(mlx.core.broadcast_shapes((0,), (0,)), (0,)) - self.assertEqual(mlx.core.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) - self.assertEqual(mlx.core.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) - - # Error cases - with self.assertRaises(ValueError): - mlx.core.broadcast_shapes((3, 4), (4, 3)) - - with self.assertRaises(ValueError): - mlx.core.broadcast_shapes((2, 3, 4), (2, 5, 4)) - - with self.assertRaises(ValueError): - mlx.core.broadcast_shapes() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4fcb31f18..5008abb3f 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2935,5 +2935,45 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) +class TestBroadcast(mlx_tests.MLXTestCase): + def test_broadcast_shapes(self): + # Basic broadcasting + self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + + # Multiple arguments + self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual( + mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + ) + + # Same shapes + self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + + # Single argument + self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3)) + + # Empty shapes + self.assertEqual(mx.broadcast_shapes((), ()), ()) + self.assertEqual(mx.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mx.broadcast_shapes((1,), ()), (1,)) + + # Broadcasting with zeroes + self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,)) + self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + + # Error cases + with self.assertRaises(ValueError): + mx.broadcast_shapes((3, 4), (4, 3)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes((2, 3, 4), (2, 5, 4)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes() + + if __name__ == "__main__": unittest.main()