This commit is contained in:
Awni Hannun 2025-06-10 10:54:53 -07:00
parent 99c33d011d
commit 283a136c64
8 changed files with 47 additions and 5 deletions

View File

@ -55,6 +55,9 @@ endif()
if(MLX_BUILD_CUDA) if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif() endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)

View File

@ -12,6 +12,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/event.cu

11
mlx/backend/cuda/cuda.cpp Normal file
View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
namespace mlx::core::cu {
bool is_available() {
return true;
}
} // namespace mlx::core::cu

10
mlx/backend/cuda/cuda.h Normal file
View File

@ -0,0 +1,10 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cu {
/* Check if the CUDA backend is available. */
bool is_available();
} // namespace mlx::core::cu

View File

@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
namespace mlx::core::cu {
bool is_available() {
return false;
}
} // namespace mlx::core::cu

View File

@ -3,6 +3,7 @@
#pragma once #pragma once
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/compile.h" #include "mlx/compile.h"
#include "mlx/device.h" #include "mlx/device.h"

View File

@ -17,10 +17,7 @@
#include "python/src/indexing.h" #include "python/src/indexing.h"
#include "python/src/utils.h" #include "python/src/utils.h"
#include "mlx/device.h" #include "mlx/mlx.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
namespace mx = mlx::core; namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
.def( .def(
"__dlpack_device__", "__dlpack_device__",
[](const mx::array& a) { [](const mx::array& a) {
// See
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
if (mx::metal::is_available()) { if (mx::metal::is_available()) {
// Metal device is available
return nb::make_tuple(8, 0); return nb::make_tuple(8, 0);
} else if (mx::cu::is_available()) {
return nb::make_tuple(13, 0);
} else { } else {
// CPU device // CPU device
return nb::make_tuple(1, 0); return nb::make_tuple(1, 0);

View File

@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
&mx::set_default_device, &mx::set_default_device,
"device"_a, "device"_a,
R"pbdoc(Set the default device.)pbdoc"); R"pbdoc(Set the default device.)pbdoc");
m.def(
"is_available",
&mx::is_available,
"device"_a,
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
} }