From 92444f393b0c5e525a47c73d5689c47fd7f4cba5 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Sun, 20 Apr 2025 12:58:26 +0900 Subject: [PATCH] nit --- python/src/ops.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8a70cba05..2c882490e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5192,17 +5192,15 @@ void init_ops(nb::module_& m) { m.def( "broadcast_shapes", [](const nb::args& shapes) { - if (shapes.size() == 0) { + if (shapes.size() == 0) throw std::invalid_argument( "broadcast_shapes expects a sequence of shapes"); - } std::vector shape_vec; shape_vec.reserve(shapes.size()); for (size_t i = 0; i < shapes.size(); ++i) { mx::Shape shape; - if (nb::isinstance(shapes[i])) { nb::tuple t = nb::cast(shapes[i]); for (size_t j = 0; j < t.size(); ++j) { @@ -5221,14 +5219,9 @@ void init_ops(nb::module_& m) { shape_vec.push_back(shape); } - if (shape_vec.empty()) { - return nb::tuple(); - } - mx::Shape result = shape_vec[0]; - for (size_t i = 1; i < shape_vec.size(); ++i) { + for (size_t i = 1; i < shape_vec.size(); ++i) result = mx::broadcast_shapes(result, shape_vec[i]); - } auto py_list = nb::cast(result); return nb::tuple(py_list);