From 283a136c6446d3e26cbe5efc655f96663e74af54 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 10 Jun 2025 10:54:53 -0700 Subject: [PATCH] rebase --- mlx/CMakeLists.txt | 3 +++ mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/cuda.cpp | 11 +++++++++++ mlx/backend/cuda/cuda.h | 10 ++++++++++ mlx/backend/cuda/no_cuda.cpp | 11 +++++++++++ mlx/mlx.h | 1 + python/src/array.cpp | 10 +++++----- python/src/device.cpp | 5 +++++ 8 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 mlx/backend/cuda/cuda.cpp create mode 100644 mlx/backend/cuda/cuda.h create mode 100644 mlx/backend/cuda/no_cuda.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index ce921b276..7aa648533 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -55,6 +55,9 @@ endif() if(MLX_BUILD_CUDA) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) endif() if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7ffbcb2d3..9d9657e1f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/cuda.cpp b/mlx/backend/cuda/cuda.cpp new file mode 100644 index 000000000..ceb4d7dfe --- /dev/null +++ b/mlx/backend/cuda/cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/cuda.h b/mlx/backend/cuda/cuda.h new file mode 100644 index 000000000..2c6a5c724 --- /dev/null +++ b/mlx/backend/cuda/cuda.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cu { + +/* Check if the CUDA backend is available. */ +bool is_available(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp new file mode 100644 index 000000000..8a394c9e3 --- /dev/null +++ b/mlx/backend/cuda/no_cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cu diff --git a/mlx/mlx.h b/mlx/mlx.h index cef8d806d..de3ee392a 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/metal/metal.h" #include "mlx/compile.h" #include "mlx/device.h" diff --git a/python/src/array.cpp b/python/src/array.cpp index 5ba0aaedc..25889d775 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -17,10 +17,7 @@ #include "python/src/indexing.h" #include "python/src/utils.h" -#include "mlx/device.h" -#include "mlx/ops.h" -#include "mlx/transforms.h" -#include "mlx/utils.h" +#include "mlx/mlx.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -461,9 +458,12 @@ void init_array(nb::module_& m) { .def( "__dlpack_device__", [](const mx::array& a) { + // See + // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74 if (mx::metal::is_available()) { - // Metal device is available return nb::make_tuple(8, 0); + } else if (mx::cu::is_available()) { + return nb::make_tuple(13, 0); } else { // CPU device return nb::make_tuple(1, 0); diff --git a/python/src/device.cpp b/python/src/device.cpp index 85b15dd4d..006a05dc0 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -58,4 +58,9 @@ void init_device(nb::module_& m) { &mx::set_default_device, "device"_a, R"pbdoc(Set the default device.)pbdoc"); + m.def( + "is_available", + &mx::is_available, + "device"_a, + R"pbdoc(Check if a back-end is available for the given device.)pbdoc"); }