diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2572a6d53..5969c5052 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5196,8 +5196,8 @@ void init_ops(nb::module_& m) { throw std::invalid_argument( "[broadcast_shapes] Must provide at least one shape."); - mx::Shape result; - for (size_t i = 0; i < shapes.size(); ++i) { + mx::Shape result = nb::cast(shapes[0]); + for (size_t i = 1; i < shapes.size(); ++i) { if (!nb::isinstance(shapes[i]) && !nb::isinstance(shapes[i])) throw std::invalid_argument(