diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 19d089446..2572a6d53 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5194,26 +5194,18 @@ void init_ops(nb::module_& m) { [](const nb::args& shapes) { if (shapes.size() == 0) throw std::invalid_argument( - "broadcast_shapes expects a sequence of shapes"); - - std::vector shape_vec; - shape_vec.reserve(shapes.size()); + "[broadcast_shapes] Must provide at least one shape."); + mx::Shape result; for (size_t i = 0; i < shapes.size(); ++i) { if (!nb::isinstance(shapes[i]) && !nb::isinstance(shapes[i])) throw std::invalid_argument( - "broadcast_shapes expects a sequence of shapes (tuple or list of ints)"); - - shape_vec.push_back(nb::cast(shapes[i])); + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); } - mx::Shape result = shape_vec[0]; - for (size_t i = 1; i < shape_vec.size(); ++i) - result = mx::broadcast_shapes(result, shape_vec[i]); - - auto py_list = nb::cast(result); - return nb::tuple(py_list); + return nb::tuple(nb::cast(result)); }, nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), R"pbdoc(