Compare commits


3 Commits

Author | SHA1 | Message | Date

Cheng | a9bac3d9e5 | Run CPP tests for CUDA build in CI (#2544) | 2025-08-27 08:06:46 +09:00
Awni Hannun | 5458d43247 | add load with path tests (#2543) | 2025-08-26 14:24:47 -07:00
Awni Hannun | a4dba65220 | Enable cuda graph toggle (#2545) | 2025-08-26 12:50:38 -07:00
  * enable cuda graph toggle
  * increase cache size
9 changed files with 109 additions and 17 deletions

View File

@@ -232,6 +232,7 @@ jobs:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
@@ -240,6 +241,18 @@ jobs:
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m unittest discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |

View File

@@ -140,6 +140,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions the following libs are implicitly linked, but
# when building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.

View File

@@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}();
return cache_size;
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
return use_graphs;
}
} // namespace
Device::Device(int device) : device_(device) {
@@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
if (!use_cuda_graphs()) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
return;
}
graph.end_capture(enc.stream());
if (discard) {
return;
@@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false;
if (!use_cuda_graphs()) {
return;
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
}
void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
void CommandEncoder::set_output_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
active_outputs_.push_back(id);
@@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
}
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
@@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
smem_bytes,
stream(),
params,
nullptr));
return;
}
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x;
@@ -256,6 +301,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
return;
}
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
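
Taken together, the device.cpp changes above make CUDA graph capture optional: MLX_USE_CUDA_GRAPHS (default on) gates every capture and dependency-tracking path, each early-out falls back to launching work directly on the stream, and the default graph cache size grows from 100 to 400 entries (MLX_CUDA_GRAPH_CACHE_SIZE). A minimal usage sketch, assuming both variables are read once by the static initializers above and therefore must be set before the backend first initializes:

    import os

    # Assumption: the static initializers cache these values on first read,
    # so they must be exported before mlx.core is imported and the CUDA
    # backend comes up.
    os.environ["MLX_USE_CUDA_GRAPHS"] = "0"  # disable capture, launch eagerly
    os.environ["MLX_CUDA_GRAPH_CACHE_SIZE"] = "800"  # only read when graphs are on

    import mlx.core as mx

    a = mx.ones((1024, 1024))
    mx.eval((a @ a).sum())  # runs via plain stream launches with graphs disabled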

View File

@@ -76,9 +76,6 @@ class CommandEncoder {
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
- void add_kernel_node(const cudaKernelNodeParams& params);
- void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
@@ -101,6 +98,9 @@ class CommandEncoder {
void synchronize();
private:
+ void add_kernel_node(const cudaKernelNodeParams& params);
+ void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
struct GraphNode {
cudaGraphNode_t node;
// K = kernel

View File

@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
path = env;
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
}
return std::string();
}();

View File

@@ -4,6 +4,7 @@
#include "mlx/array.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"

View File

@@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import mlx.core as mx
import mlx_tests
@@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
save_file = os.path.join(self.test_dir, "mlx_path.npy")
save_arr = mx.ones((32,))
mx.save(Path(save_file), save_arr)
# Load array saved by mlx as mlx array
load_arr = mx.load(Path(save_file))
self.assertTrue(mx.array_equal(load_arr, save_arr))
def test_save_and_load_safetensors(self):
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
- mx.save_safetensors(
-     test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
- )
- res = mx.load(test_file, return_metadata=True)
- self.assertEqual(len(res), 2)
- self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
+ for obj in [str, Path]:
+     mx.save_safetensors(
+         obj(test_file),
+         {"test": mx.ones((2, 2))},
+         {"testing": "test", "format": "mlx"},
+     )
+     res = mx.load(obj(test_file), return_metadata=True)
+     self.assertEqual(len(res), 2)
+     self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
for dt in self.dtypes + ["bfloat16"]:
with self.subTest(dtype=dt):
@@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.array_equal(load_dict["test"], save_dict["test"])
)
save_file_mlx = os.path.join(self.test_dir, "mlx_path_test_fs.gguf")
save_dict = {"test": mx.ones(shape)}
mx.save_gguf(Path(save_file_mlx), save_dict)
load_dict = mx.load(Path(save_file_mlx))
self.assertTrue("test" in load_dict)
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
def test_load_f8_e4m3(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
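
For quick reference, a condensed sketch of the path-accepting API these tests cover; the file names are illustrative, and the calls mirror the assertions above:

    from pathlib import Path

    import mlx.core as mx

    # mx.save / mx.load accept pathlib.Path as well as str.
    npy = Path("example.npy")
    mx.save(npy, mx.ones((32,)))
    assert mx.array_equal(mx.load(npy), mx.ones((32,)))

    # The same applies to the safetensors (and GGUF) entry points.
    st = Path("example.safetensors")
    mx.save_safetensors(st, {"w": mx.zeros((2, 2))}, {"format": "mlx"})
    tensors, metadata = mx.load(st, return_metadata=True)
    assert metadata == {"format": "mlx"}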

View File

@@ -10,7 +10,7 @@ using namespace mlx::core;
TEST_CASE("test device placement") {
auto device = default_device();
- Device d = metal::is_available() ? Device::gpu : Device::cpu;
+ Device d = gpu::is_available() ? Device::gpu : Device::cpu;
if (std::getenv("DEVICE") == nullptr) {
CHECK_EQ(device, d);
}
@@ -18,7 +18,7 @@ TEST_CASE("test device placement") {
array x(1.0f);
array y(1.0f);
auto z = add(x, y, default_device());
- if (metal::is_available()) {
+ if (gpu::is_available()) {
z = add(x, y, Device::gpu);
z = add(x, y, Device(Device::gpu, 0));
} else {

View File

@@ -16,7 +16,7 @@ TEST_CASE("test stream management") {
CHECK_NE(s1, s2);
// Check that default streams have the correct devices
- if (metal::is_available()) {
+ if (gpu::is_available()) {
auto s_gpu = default_stream(Device::gpu);
CHECK_EQ(s_gpu.device, Device::gpu);
} else {
@@ -28,7 +28,7 @@ TEST_CASE("test stream management") {
s_cpu = new_stream(Device::cpu);
CHECK_EQ(s_cpu.device, Device::cpu);
- if (metal::is_available()) {
+ if (gpu::is_available()) {
auto s_gpu = new_stream(Device::gpu);
CHECK_EQ(s_gpu.device, Device::gpu);
} else {