fix to use c++ api

This commit is contained in:
Hyunsung Lee 2025-04-20 12:55:58 +09:00
parent 876c1986e4
commit a7a96b0ad6
3 changed files with 85 additions and 78 deletions

View File

@ -289,64 +289,3 @@ def tree_merge(tree_a, tree_b, merge_fn=None):
)
)
return merge_fn(tree_a, tree_b)
def broadcast_shapes(*shapes):
"""Broadcast shapes to the same size.
Uses the same broadcasting rules as NumPy. The size of the trailing axes
for both arrays in an operation must either be the same size or one of
them must be one.
Args:
*shapes: The shapes to be broadcast against each other.
Each shape should be a tuple or list of integers.
Returns:
A tuple of integers representing the broadcasted shape.
Raises:
ValueError: If the shapes cannot be broadcast according to broadcasting rules.
Examples:
>>> broadcast_shapes((1, 2, 3), (3,))
(1, 2, 3)
>>> broadcast_shapes((1, 2, 3), (4, 1, 3))
(4, 2, 3)
>>> broadcast_shapes((5, 1, 3), (1, 4, 3))
(5, 4, 3)
"""
if len(shapes) == 0:
raise ValueError("No shapes provided")
if len(shapes) == 1:
return shapes[0]
result = shapes[0]
for shape in shapes[1:]:
ndim1 = len(result)
ndim2 = len(shape)
ndim = max(ndim1, ndim2)
diff = abs(ndim1 - ndim2)
big = result if ndim1 > ndim2 else shape
small = shape if ndim1 > ndim2 else result
out_shape = []
for i in range(ndim - 1, diff - 1, -1):
a = big[i]
b = small[i - diff]
if a == b:
out_shape.insert(0, a)
elif a == 1 or b == 1:
out_shape.insert(0, a * b)
else:
raise ValueError(
f"Shapes {result} and {shape} cannot be broadcast together"
)
for i in range(diff - 1, -1, -1):
out_shape.insert(0, big[i])
result = tuple(out_shape)
return result

View File

@ -5189,4 +5189,72 @@ void init_ops(nb::module_& m) {
Returns:
array: The row or col contiguous output.
)pbdoc");
m.def(
"broadcast_shapes",
[](const nb::args& shapes) {
if (shapes.size() == 0) {
throw std::invalid_argument(
"broadcast_shapes expects a sequence of shapes");
}
std::vector<mx::Shape> shape_vec;
shape_vec.reserve(shapes.size());
for (size_t i = 0; i < shapes.size(); ++i) {
mx::Shape shape;
if (nb::isinstance<nb::tuple>(shapes[i])) {
nb::tuple t = nb::cast<nb::tuple>(shapes[i]);
for (size_t j = 0; j < t.size(); ++j) {
shape.push_back(nb::cast<int>(t[j]));
}
} else if (nb::isinstance<nb::list>(shapes[i])) {
nb::list l = nb::cast<nb::list>(shapes[i]);
for (size_t j = 0; j < l.size(); ++j) {
shape.push_back(nb::cast<int>(l[j]));
}
} else {
throw std::invalid_argument(
"broadcast_shapes expects a sequence of shapes");
}
shape_vec.push_back(shape);
}
if (shape_vec.empty()) {
return nb::tuple();
}
mx::Shape result = shape_vec[0];
for (size_t i = 1; i < shape_vec.size(); ++i) {
result = mx::broadcast_shapes(result, shape_vec[i]);
}
auto py_list = nb::cast(result);
return nb::tuple(py_list);
},
nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Sequence[int]"),
R"pbdoc(
Broadcast shapes.
Returns the shape that results from broadcasting the supplied array shapes
against each other.
Args:
*shapes (Sequence[int]): The shapes to broadcast.
Returns:
tuple: The broadcasted shape.
Raises:
ValueError: If the shapes cannot be broadcast.
Example:
>>> mx.broadcast_shapes((1,), (3, 1))
(3, 1)
>>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,))
(5, 6, 7)
>>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1))
(5, 3, 4)
)pbdoc");
}

View File

@ -1,44 +1,44 @@
# Copyright © 2025 Apple Inc.
import mlx.utils
import mlx.core
import mlx_tests
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):
# Basic broadcasting
self.assertEqual(mlx.utils.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
self.assertEqual(mlx.utils.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6))
self.assertEqual(mlx.utils.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4))
self.assertEqual(mlx.core.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
self.assertEqual(mlx.core.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6))
self.assertEqual(mlx.core.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4))
# Multiple arguments
self.assertEqual(mlx.utils.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8))
self.assertEqual(mlx.core.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8))
self.assertEqual(
mlx.utils.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5)
mlx.core.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5)
)
# Same shapes
self.assertEqual(mlx.utils.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5))
self.assertEqual(mlx.core.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5))
# Single argument
self.assertEqual(mlx.utils.broadcast_shapes((2, 3)), (2, 3))
self.assertEqual(mlx.core.broadcast_shapes((2, 3)), (2, 3))
# Empty shapes
self.assertEqual(mlx.utils.broadcast_shapes((), ()), ())
self.assertEqual(mlx.utils.broadcast_shapes((), (1,)), (1,))
self.assertEqual(mlx.utils.broadcast_shapes((1,), ()), (1,))
self.assertEqual(mlx.core.broadcast_shapes((), ()), ())
self.assertEqual(mlx.core.broadcast_shapes((), (1,)), (1,))
self.assertEqual(mlx.core.broadcast_shapes((1,), ()), (1,))
# Broadcasting with zeroes
self.assertEqual(mlx.utils.broadcast_shapes((0,), (0,)), (0,))
self.assertEqual(mlx.utils.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5))
self.assertEqual(mlx.utils.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0))
self.assertEqual(mlx.core.broadcast_shapes((0,), (0,)), (0,))
self.assertEqual(mlx.core.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5))
self.assertEqual(mlx.core.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0))
# Error cases
with self.assertRaises(ValueError):
mlx.utils.broadcast_shapes((3, 4), (4, 3))
mlx.core.broadcast_shapes((3, 4), (4, 3))
with self.assertRaises(ValueError):
mlx.utils.broadcast_shapes((2, 3, 4), (2, 5, 4))
mlx.core.broadcast_shapes((2, 3, 4), (2, 5, 4))
with self.assertRaises(ValueError):
mlx.utils.broadcast_shapes()
mlx.core.broadcast_shapes()