diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 18151daad..0c91f0c0e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3198,12 +3198,9 @@ void init_ops(py::module_& m) { "tensordot", [](const array& a, const array& b, - const std::variant>>& - dims, + const std::variant>>& dims, StreamOrDevice s) { - if (std::holds_alternative(dims)) { - return tensordot(a, b, 2, s); - } else if (auto pv = std::get_if(&dims); pv) { + if (auto pv = std::get_if(&dims); pv) { return tensordot(a, b, *pv, s); } else { return tensordot(