Compare commits


3 Commits

Author       SHA1        Message                                     Date
Cheng        a9bac3d9e5  Run CPP tests for CUDA build in CI (#2544)  2025-08-27 08:06:46 +09:00
Awni Hannun  5458d43247  add load with path tests (#2543)            2025-08-26 14:24:47 -07:00
Awni Hannun  a4dba65220  Enable cuda graph toggle (#2545)            2025-08-26 12:50:38 -07:00
                         * enable cuda graph toggle
                         * increase cache size
9 changed files with 109 additions and 17 deletions

View File

@@ -232,6 +232,7 @@ jobs:
           name: Install Python package
           command: |
             uv venv
+            uv pip install cmake
             DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
               uv pip install -e ".[dev]" -v
       - run:
@@ -240,6 +241,18 @@ jobs:
             source .venv/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
             LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+      - run:
+          name: Build CPP only
+          command: |
+            source .venv/bin/activate
+            cmake . -B build \
+              -DMLX_BUILD_CUDA=ON \
+              -DCMAKE_CUDA_COMPILER=`which nvcc` \
+              -DCMAKE_BUILD_TYPE=DEBUG
+            cmake --build build -j `nproc`
+      - run:
+          name: Run CPP tests
+          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
       - run:
           name: CCache report
           command: |
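The new "Run CPP tests" step passes -sfe to the test binary; MLX's C++ tests use doctest, where -sfe is the --source-file-exclude filter, so the fft and linalg test files are skipped here. A sketch of reproducing the CI invocation locally from Python (paths assume the "Build CPP only" step above has already run):

# Drive the same doctest invocation as the CI step from Python.
# -sfe is doctest's source-file-exclude filter.
import subprocess

subprocess.run(
    ["./build/tests/tests", "-sfe=*fft_tests.cpp,*linalg_tests.cpp"],
    check=True,
)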

View File

@@ -140,6 +140,12 @@ elseif(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  # With newer clang/gcc versions the following libs are implicitly linked, but
+  # when building on old distributions they need to be explicitly listed.
+  target_link_libraries(mlx PRIVATE dl pthread)
+endif()
+
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.

View File

@@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 int cuda_graph_cache_size() {
   static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
   }();
   return cache_size;
 }
 
+bool use_cuda_graphs() {
+  static bool use_graphs = []() {
+    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
+  }();
+  return use_graphs;
+}
+
 } // namespace
 
 Device::Device(int device) : device_(device) {
@@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   enc.device().make_current();
+  if (!use_cuda_graphs()) {
+    return;
+  }
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   graph.end_capture(enc.stream());
   if (discard) {
     return;
@@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
 
 CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   enc.in_concurrent_ = false;
+  if (!use_cuda_graphs()) {
+    return;
+  }
 
   // Use an empty graph node for synchronization
   CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
 }
 
 void CommandEncoder::set_input_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
 
 void CommandEncoder::set_output_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
   active_outputs_.push_back(id);
@@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cudaLaunchKernel(
+        func, grid_dim, block_dim, params, smem_bytes, stream()));
+    return;
+  }
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
@@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cuLaunchKernel(
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr));
+    return;
+  }
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDimX = grid_dim.x;
@@ -256,6 +301,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 }
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  if (!use_cuda_graphs()) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(child);
+    device_.make_current();
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
+  }
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(GraphNode{node, 'G'});
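Both knobs above are read once through env::get_var in static initializers, so they only take effect if set before MLX touches the CUDA device. A minimal usage sketch, assuming "0"/"1" parse as the boolean (the variable names come from the diff; the workload is illustrative):

# Disable CUDA graph capture, falling back to the plain cudaLaunchKernel /
# cuLaunchKernel paths added above; set the variables before importing mlx.
import os

os.environ["MLX_USE_CUDA_GRAPHS"] = "0"
os.environ["MLX_CUDA_GRAPH_CACHE_SIZE"] = "800"  # only relevant when graphs are on

import mlx.core as mx

a = mx.ones((1024, 1024))
mx.eval((a @ a).sum())  # runs without cudaStreamBeginCapture when toggled off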

View File

@@ -76,9 +76,6 @@ class CommandEncoder {
       uint32_t smem_bytes,
       void** params);
 
-  // Low-level graph helpers.
-  void add_kernel_node(const cudaKernelNodeParams& params);
-  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
-
   void add_graph_node(cudaGraph_t child);
 
   void add_temporary(const array& arr) {
@@ -101,6 +98,9 @@ class CommandEncoder {
   void synchronize();
 
  private:
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+
   struct GraphNode {
     cudaGraphNode_t node;
     // K = kernel

View File

@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
       return path.string();
     }
     // Finally check the environment variable.
-    path = std::getenv("MLX_CCCL_DIR");
-    if (!path.empty() && std::filesystem::exists(path)) {
-      return path.string();
+    if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
+      path = env;
+      if (!path.empty() && std::filesystem::exists(path)) {
+        return path.string();
+      }
     }
     return std::string();
   }();
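This change fixes a lookup hazard: std::getenv returns a null pointer when MLX_CCCL_DIR is unset, and assigning that directly to a std::filesystem::path is undefined behavior; the C++17 if-with-initializer scopes the null check. A rough Python analogue of the guarded lookup, purely illustrative (the function name is a stand-in):

# Only treat MLX_CCCL_DIR as a directory when it is actually set and exists,
# mirroring the guarded env-var lookup above.
import os

def cccl_dir_from_env() -> str:
    env = os.environ.get("MLX_CCCL_DIR")  # None when unset, like a null char*
    if env and os.path.exists(env):
        return env
    return ""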

View File

@@ -4,6 +4,7 @@
 #include "mlx/array.h"
 #include "mlx/backend/cuda/cuda.h"
+#include "mlx/backend/gpu/available.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"

View File

@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import mlx.core as mx
 import mlx_tests
@@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
         load_arr_mlx_npy = np.load(save_file_mlx)
         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
 
+        save_file = os.path.join(self.test_dir, "mlx_path.npy")
+        save_arr = mx.ones((32,))
+        mx.save(Path(save_file), save_arr)
+
+        # Load array saved by mlx as mlx array
+        load_arr = mx.load(Path(save_file))
+        self.assertTrue(mx.array_equal(load_arr, save_arr))
+
     def test_save_and_load_safetensors(self):
         test_file = os.path.join(self.test_dir, "test.safetensors")
         with self.assertRaises(Exception):
             mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
 
-        mx.save_safetensors(
-            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
-        )
-        res = mx.load(test_file, return_metadata=True)
-        self.assertEqual(len(res), 2)
-        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
+        for obj in [str, Path]:
+            mx.save_safetensors(
+                obj(test_file),
+                {"test": mx.ones((2, 2))},
+                {"testing": "test", "format": "mlx"},
+            )
+            res = mx.load(obj(test_file), return_metadata=True)
+            self.assertEqual(len(res), 2)
+            self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
 
         for dt in self.dtypes + ["bfloat16"]:
             with self.subTest(dtype=dt):
@@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
             mx.array_equal(load_dict["test"], save_dict["test"])
         )
 
+        save_file_mlx = os.path.join(self.test_dir, "mlx_path_test_fs.gguf")
+        save_dict = {"test": mx.ones(shape)}
+        mx.save_gguf(Path(save_file_mlx), save_dict)
+        load_dict = mx.load(Path(save_file_mlx))
+        self.assertTrue("test" in load_dict)
+        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
+
     def test_load_f8_e4m3(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
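Together these tests pin down that mx.save, mx.save_safetensors, mx.save_gguf, and mx.load accept pathlib.Path as well as str. A minimal round-trip sketch of that behavior (the file name is illustrative):

# Save and load through a pathlib.Path, as exercised by the tests above.
import tempfile
from pathlib import Path

import mlx.core as mx

out = Path(tempfile.mkdtemp()) / "weights.safetensors"
mx.save_safetensors(out, {"w": mx.zeros((2, 2))}, {"format": "mlx"})
arrays, metadata = mx.load(out, return_metadata=True)
assert mx.array_equal(arrays["w"], mx.zeros((2, 2)))
assert metadata == {"format": "mlx"}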

View File

@@ -10,7 +10,7 @@ using namespace mlx::core;
 TEST_CASE("test device placement") {
   auto device = default_device();
-  Device d = metal::is_available() ? Device::gpu : Device::cpu;
+  Device d = gpu::is_available() ? Device::gpu : Device::cpu;
   if (std::getenv("DEVICE") == nullptr) {
     CHECK_EQ(device, d);
   }
@@ -18,7 +18,7 @@ TEST_CASE("test device placement") {
   array x(1.0f);
   array y(1.0f);
   auto z = add(x, y, default_device());
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     z = add(x, y, Device::gpu);
     z = add(x, y, Device(Device::gpu, 0));
   } else {

View File

@@ -16,7 +16,7 @@ TEST_CASE("test stream management") {
   CHECK_NE(s1, s2);
 
   // Check that default streams have the correct devices
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     auto s_gpu = default_stream(Device::gpu);
     CHECK_EQ(s_gpu.device, Device::gpu);
   } else {
@@ -28,7 +28,7 @@ TEST_CASE("test stream management") {
   s_cpu = new_stream(Device::cpu);
   CHECK_EQ(s_cpu.device, Device::cpu);
 
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     auto s_gpu = new_stream(Device::gpu);
     CHECK_EQ(s_gpu.device, Device::gpu);
   } else {