diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9a76188fa..18151daad 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3198,7 +3198,8 @@ void init_ops(py::module_& m) { "tensordot", [](const array& a, const array& b, - const IntOrIntVec& dims, + const std::variant>>& + dims, StreamOrDevice s) { if (std::holds_alternative(dims)) { return tensordot(a, b, 2, s); diff --git a/python/src/utils.h b/python/src/utils.h index 86da9f6db..5ac878979 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -15,8 +15,6 @@ namespace py = pybind11; using namespace mlx::core; using IntOrVec = std::variant>; -using IntOrIntVec = - std::variant>>; using ScalarOrArray = std:: variant, py::object>; static constexpr std::monostate none{};