From 5e8712cf6fdeac612f4daacf663af8f8ea94f0e4 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Tue, 22 Apr 2025 05:47:07 +0900 Subject: [PATCH] Fix upon review Update python/src/ops.cpp Co-authored-by: Awni Hannun --- python/src/ops.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 19d089446..2572a6d53 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5194,26 +5194,18 @@ void init_ops(nb::module_& m) { [](const nb::args& shapes) { if (shapes.size() == 0) throw std::invalid_argument( - "broadcast_shapes expects a sequence of shapes"); - - std::vector shape_vec; - shape_vec.reserve(shapes.size()); + "[broadcast_shapes] Must provide at least one shape."); + mx::Shape result; for (size_t i = 0; i < shapes.size(); ++i) { if (!nb::isinstance(shapes[i]) && !nb::isinstance(shapes[i])) throw std::invalid_argument( - "broadcast_shapes expects a sequence of shapes (tuple or list of ints)"); - - shape_vec.push_back(nb::cast(shapes[i])); + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); } - mx::Shape result = shape_vec[0]; - for (size_t i = 1; i < shape_vec.size(); ++i) - result = mx::broadcast_shapes(result, shape_vec[i]); - - auto py_list = nb::cast(result); - return nb::tuple(py_list); + return nb::tuple(nb::cast(result)); }, nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), R"pbdoc(