mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

start cuda circle config (#2256)

* rebase
* fix metal kernel linking issue on cuda
* start cuda circle config

parent 8590c0941e
commit c35f4d089a
@@ -212,6 +212,29 @@ jobs:
             METAL_DEBUG_ERROR_MODE=0 \
             python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

+  cuda_build_and_test:
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Install Python package
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            python -m venv env
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            pip install -e ".[dev]"
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+
   build_release:
     parameters:
       python_version:
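As a side note, the install step above can be reproduced locally; a minimal Python smoke check after `pip install -e ".[dev]"` might look like the sketch below (it assumes the `mx.is_available` binding added later in this commit):

    import mlx.core as mx

    # The CI job runs the full suite with DEVICE=cpu; this only checks that
    # the package imports, ops run, and whether a GPU backend (Metal or
    # CUDA) was compiled in.
    x = mx.add(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))
    print(x)
    print(mx.is_available(mx.gpu))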
@@ -348,6 +371,7 @@ workflows:
           parameters:
             macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
+      - cuda_build_and_test
       - build_documentation

   build_pypi_release:
@@ -455,6 +479,8 @@ workflows:
             macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test:
           requires: [ hold ]
+      - cuda_build_and_test:
+          requires: [ hold ]
   nightly_build:
     when:
       and:
@@ -55,6 +55,9 @@ endif()

 if(MLX_BUILD_CUDA)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
 endif()

 if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
@@ -12,6 +12,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
mlx/backend/cuda/cuda.cpp (new file, 11 lines)
@@ -0,0 +1,11 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mlx::core::cu {
+
+bool is_available() {
+  return true;
+}
+
+} // namespace mlx::core::cu
mlx/backend/cuda/cuda.h (new file, 10 lines)
@@ -0,0 +1,10 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+namespace mlx::core::cu {
+
+/* Check if the CUDA backend is available. */
+bool is_available();
+
+} // namespace mlx::core::cu
mlx/backend/cuda/no_cuda.cpp (new file, 11 lines)
@@ -0,0 +1,11 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mlx::core::cu {
+
+bool is_available() {
+  return false;
+}
+
+} // namespace mlx::core::cu
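Together with the root CMake change above, exactly one of `cuda.cpp` (returns true) or `no_cuda.cpp` (returns false) is compiled into the library, so `cu::is_available()` reports a build-time decision rather than probing hardware at runtime. A sketch of the observable effect from Python (again assuming the bindings added later in this commit):

    import mlx.core as mx

    # Build-time stub selection surfaces at runtime:
    #   CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" -> cuda.cpp    -> cu::is_available() == true
    #   default build without CUDA       -> no_cuda.cpp -> cu::is_available() == false
    # mx.is_available(mx.gpu) is True when either GPU backend is compiled in.
    print(mx.is_available(mx.gpu))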
@@ -3,8 +3,11 @@
 #include <stdexcept>

 #include "mlx/backend/metal/metal.h"
+#include "mlx/fast.h"

-namespace mlx::core::metal {
+namespace mlx::core {
+
+namespace metal {

 bool is_available() {
   return false;
@@ -19,4 +22,21 @@ device_info()
       "[metal::device_info] Cannot get device info without metal backend");
 };

-} // namespace mlx::core::metal
+} // namespace metal
+
+namespace fast {
+
+MetalKernelFunction metal_kernel(
+    const std::string&,
+    const std::vector<std::string>&,
+    const std::vector<std::string>&,
+    const std::string&,
+    const std::string&,
+    bool ensure_row_contiguous,
+    bool atomic_outputs) {
+  throw std::runtime_error("[metal_kernel] No GPU back-end.");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
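This move pairs with the removal in the hunks just below: the throwing `fast::metal_kernel` stub previously lived in the no-GPU primitives file, which a CUDA build does not compile, leaving the symbol undefined; defining it in the no-Metal translation unit (compiled whenever Metal is off, including CUDA builds) is presumably the "fix metal kernel linking issue on cuda" from the commit message. On such builds the error now surfaces as a clean exception; a sketch of what a caller would see (hypothetical kernel name and parameters; depending on the bindings, the exception may be raised at construction or at first use):

    import mlx.core as mx

    try:
        kernel = mx.fast.metal_kernel(
            name="exp_elementwise",  # hypothetical example kernel
            input_names=["inp"],
            output_names=["out"],
            source="",               # never compiled on non-Metal builds
        )
    except RuntimeError as e:
        print(e)  # [metal_kernel] No GPU back-end.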
@@ -2,7 +2,6 @@

 #include "mlx/primitives.h"
 #include "mlx/distributed/primitives.h"
-#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"

 #define NO_GPU_MULTI(func) \
@@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
-
-MetalKernelFunction metal_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool ensure_row_contiguous,
-    bool atomic_outputs) {
-  throw std::runtime_error("[metal_kernel] No GPU back-end.");
-}
-
 } // namespace fast

 namespace distributed {
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
@@ -17,10 +17,7 @@
 #include "python/src/indexing.h"
 #include "python/src/utils.h"

-#include "mlx/device.h"
-#include "mlx/ops.h"
-#include "mlx/transforms.h"
-#include "mlx/utils.h"
+#include "mlx/mlx.h"

 namespace mx = mlx::core;
 namespace nb = nanobind;
@@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
       .def(
           "__dlpack_device__",
           [](const mx::array& a) {
+            // See
+            // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
             if (mx::metal::is_available()) {
-              // Metal device is available
               return nb::make_tuple(8, 0);
+            } else if (mx::cu::is_available()) {
+              return nb::make_tuple(13, 0);
             } else {
               // CPU device
               return nb::make_tuple(1, 0);
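The device codes follow the linked dlpack.h DLDeviceType enum: 1 is kDLCPU, 8 is kDLMetal, and 13 is kDLCUDAManaged. A small usage sketch, mirroring the test update further down:

    import mlx.core as mx

    a = mx.array([1, 2, 3])
    device_type, device_id = a.__dlpack_device__()
    # 1 = kDLCPU, 8 = kDLMetal, 13 = kDLCUDAManaged (see dlpack.h link above)
    print(device_type, device_id)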
@@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
       &mx::set_default_device,
       "device"_a,
       R"pbdoc(Set the default device.)pbdoc");
+  m.def(
+      "is_available",
+      &mx::is_available,
+      "device"_a,
+      R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
 }
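This gives Python callers a backend-agnostic availability check, which the test changes below switch to. A short usage sketch:

    import mlx.core as mx

    if mx.is_available(mx.gpu):
        s_gpu = mx.new_stream(mx.gpu)  # safe: a GPU backend is compiled in
    else:
        s_cpu = mx.new_stream(mx.cpu)  # fall back to a CPU stream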
@@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
     def test_dlx_device_type(self):
         a = mx.array([1, 2, 3])
         device_type, device_id = a.__dlpack_device__()
-        self.assertIn(device_type, [1, 8])
+        self.assertIn(device_type, [1, 8, 13])
         self.assertEqual(device_id, 0)

         if device_type == 8:
@@ -10,7 +10,7 @@ import mlx_tests
 class TestDefaultDevice(unittest.TestCase):
     def test_mlx_default_device(self):
         device = mx.default_device()
-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             self.assertEqual(device, mx.Device(mx.gpu))
             self.assertEqual(str(device), "Device(gpu, 0)")
             self.assertEqual(device, mx.gpu)
@@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
         self.assertEqual(s2.device, mx.default_device())
         self.assertNotEqual(s1, s2)

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.default_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
         s_cpu = mx.new_stream(mx.cpu)
         self.assertEqual(s_cpu.device, mx.cpu)

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             s_gpu = mx.new_stream(mx.gpu)
             self.assertEqual(s_gpu.device, mx.gpu)
         else:
@@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):

         a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))

-        if mx.metal.is_available():
+        if mx.is_available(mx.gpu):
             b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
             self.assertEqual(a.item(), b.item())
             s_gpu = mx.new_stream(mx.gpu)
@@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


-class TestSchedulers(unittest.TestCase):
+class TestSchedulers(mlx_tests.MLXTestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
             lr_schedule = opt.step_decay(1e-1, 0.9, 1)