fix to use c++ api

This commit is contained in:
Hyunsung Lee
2025-04-20 12:55:58 +09:00
parent 876c1986e4
commit a7a96b0ad6
3 changed files with 85 additions and 78 deletions

View File

@@ -5189,4 +5189,72 @@ void init_ops(nb::module_& m) {
Returns:
array: The row or col contiguous output.
)pbdoc");
m.def(
"broadcast_shapes",
[](const nb::args& shapes) {
if (shapes.size() == 0) {
throw std::invalid_argument(
"broadcast_shapes expects a sequence of shapes");
}
std::vector<mx::Shape> shape_vec;
shape_vec.reserve(shapes.size());
for (size_t i = 0; i < shapes.size(); ++i) {
mx::Shape shape;
if (nb::isinstance<nb::tuple>(shapes[i])) {
nb::tuple t = nb::cast<nb::tuple>(shapes[i]);
for (size_t j = 0; j < t.size(); ++j) {
shape.push_back(nb::cast<int>(t[j]));
}
} else if (nb::isinstance<nb::list>(shapes[i])) {
nb::list l = nb::cast<nb::list>(shapes[i]);
for (size_t j = 0; j < l.size(); ++j) {
shape.push_back(nb::cast<int>(l[j]));
}
} else {
throw std::invalid_argument(
"broadcast_shapes expects a sequence of shapes");
}
shape_vec.push_back(shape);
}
if (shape_vec.empty()) {
return nb::tuple();
}
mx::Shape result = shape_vec[0];
for (size_t i = 1; i < shape_vec.size(); ++i) {
result = mx::broadcast_shapes(result, shape_vec[i]);
}
auto py_list = nb::cast(result);
return nb::tuple(py_list);
},
nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Sequence[int]"),
R"pbdoc(
Broadcast shapes.
Returns the shape that results from broadcasting the supplied array shapes
against each other.
Args:
*shapes (Sequence[int]): The shapes to broadcast.
Returns:
tuple: The broadcasted shape.
Raises:
ValueError: If the shapes cannot be broadcast.
Example:
>>> mx.broadcast_shapes((1,), (3, 1))
(3, 1)
>>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,))
(5, 6, 7)
>>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1))
(5, 3, 4)
)pbdoc");
}