mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
rebase
This commit is contained in:
parent
99c33d011d
commit
283a136c64
@ -55,6 +55,9 @@ endif()
|
|||||||
|
|
||||||
if(MLX_BUILD_CUDA)
|
if(MLX_BUILD_CUDA)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||||
|
else()
|
||||||
|
target_sources(mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||||
|
@ -12,6 +12,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
/* Check if the CUDA backend is available. */
|
||||||
|
bool is_available();
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
11
mlx/backend/cuda/no_cuda.cpp
Normal file
11
mlx/backend/cuda/no_cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -3,6 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
|
@ -17,10 +17,7 @@
|
|||||||
#include "python/src/indexing.h"
|
#include "python/src/indexing.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
#include "mlx/device.h"
|
#include "mlx/mlx.h"
|
||||||
#include "mlx/ops.h"
|
|
||||||
#include "mlx/transforms.h"
|
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__dlpack_device__",
|
"__dlpack_device__",
|
||||||
[](const mx::array& a) {
|
[](const mx::array& a) {
|
||||||
|
// See
|
||||||
|
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
||||||
if (mx::metal::is_available()) {
|
if (mx::metal::is_available()) {
|
||||||
// Metal device is available
|
|
||||||
return nb::make_tuple(8, 0);
|
return nb::make_tuple(8, 0);
|
||||||
|
} else if (mx::cu::is_available()) {
|
||||||
|
return nb::make_tuple(13, 0);
|
||||||
} else {
|
} else {
|
||||||
// CPU device
|
// CPU device
|
||||||
return nb::make_tuple(1, 0);
|
return nb::make_tuple(1, 0);
|
||||||
|
@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
|
|||||||
&mx::set_default_device,
|
&mx::set_default_device,
|
||||||
"device"_a,
|
"device"_a,
|
||||||
R"pbdoc(Set the default device.)pbdoc");
|
R"pbdoc(Set the default device.)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"is_available",
|
||||||
|
&mx::is_available,
|
||||||
|
"device"_a,
|
||||||
|
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user