mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: 3 commits, 3dcb286baf ... a9bac3d9e5
| Author | SHA1 | Date |
|---|---|---|
| | a9bac3d9e5 | |
| | 5458d43247 | |
| | a4dba65220 | |
```diff
@@ -232,6 +232,7 @@ jobs:
       - run:
           name: Install Python package
           command: |
             uv venv
             uv pip install cmake
             DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
               uv pip install -e ".[dev]" -v
       - run:
@@ -240,6 +241,18 @@ jobs:
           command: |
             source .venv/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
             LOW_MEMORY=1 DEVICE=gpu python -m unittest discover python/tests -v
       - run:
           name: Build CPP only
           command: |
             source .venv/bin/activate
             cmake . -B build \
               -DMLX_BUILD_CUDA=ON \
               -DCMAKE_CUDA_COMPILER=`which nvcc` \
               -DCMAKE_BUILD_TYPE=DEBUG
             cmake --build build -j `nproc`
       - run:
           name: Run CPP tests
           command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
       - run:
           name: CCache report
           command: |
```
```diff
@@ -140,6 +140,12 @@ elseif(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  # With newer clang/gcc versions the following libs are implicitly linked, but
+  # when building on old distributions they need to be explicitly listed.
+  target_link_libraries(mlx PRIVATE dl pthread)
+endif()
+
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.
```
```diff
@@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 
 int cuda_graph_cache_size() {
   static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
   }();
   return cache_size;
 }
 
+bool use_cuda_graphs() {
+  static bool use_graphs = []() {
+    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
+  }();
+  return use_graphs;
+}
+
 } // namespace
 
 Device::Device(int device) : device_(device) {
```
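This hunk raises the default CUDA graph cache size from 100 to 400 and adds a process-wide switch, `MLX_USE_CUDA_GRAPHS`, read once through `env::get_var` and cached in a function-local static. A minimal standalone sketch of that read-once pattern, with hypothetical `get_env_int`/`get_env_bool` helpers standing in for MLX's `env::get_var`:

```cpp
#include <cstdlib> // std::getenv, std::atoi
#include <string>

// Hypothetical stand-ins for mlx::core::env::get_var: read an environment
// variable, falling back to a default when it is unset.
static int get_env_int(const char* name, int fallback) {
  const char* raw = std::getenv(name);
  return raw ? std::atoi(raw) : fallback;
}

static bool get_env_bool(const char* name, bool fallback) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) {
    return fallback;
  }
  std::string v(raw);
  return v != "0" && v != "false" && v != "off";
}

// The caching pattern from the diff: a function-local static is initialized
// exactly once (thread-safely since C++11), so the environment is only
// consulted on the first call.
int cuda_graph_cache_size() {
  static int cache_size = get_env_int("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
  return cache_size;
}

bool use_cuda_graphs() {
  static bool use_graphs = get_env_bool("MLX_USE_CUDA_GRAPHS", true);
  return use_graphs;
}
```

One consequence of the caching is that toggling the variable after the first call has no effect; the setting is fixed for the lifetime of the process.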
```diff
@@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   enc.device().make_current();
+  if (!use_cuda_graphs()) {
+    return;
+  }
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   graph.end_capture(enc.stream());
   if (discard) {
     return;
```
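`CaptureContext` brackets the encoder's work with CUDA stream capture: everything issued on the stream between `cudaStreamBeginCapture` and the end of capture is recorded into a `cudaGraph_t` instead of executing, and the guards added here skip capture entirely when graphs are disabled. A minimal sketch of the underlying runtime-API sequence (MLX's `graph.end_capture` wraps the end/instantiate steps; this is not the MLX wrapper itself):

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK(expr)                                         \
  do {                                                      \
    cudaError_t e_ = (expr);                                \
    if (e_ != cudaSuccess) {                                \
      std::fprintf(stderr, "%s\n", cudaGetErrorString(e_)); \
      std::exit(1);                                         \
    }                                                       \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK(cudaStreamCreate(&stream));
  float* buf;
  CHECK(cudaMalloc(&buf, 1024 * sizeof(float)));

  // Everything issued on `stream` from here on is recorded, not executed.
  CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  CHECK(cudaMemsetAsync(buf, 0, 1024 * sizeof(float), stream));
  cudaGraph_t graph;
  CHECK(cudaStreamEndCapture(stream, &graph));

  // Instantiate once; the executable graph can then be relaunched cheaply.
  cudaGraphExec_t exec;
  CHECK(cudaGraphInstantiate(&exec, graph, 0)); // CUDA 12-style overload
  CHECK(cudaGraphLaunch(exec, stream));
  CHECK(cudaStreamSynchronize(stream));

  CHECK(cudaGraphExecDestroy(exec));
  CHECK(cudaGraphDestroy(graph));
  CHECK(cudaFree(buf));
  CHECK(cudaStreamDestroy(stream));
  return 0;
}
```

The payoff is amortized launch overhead: instantiation happens once, and `cudaGraphLaunch` replays the entire recorded sequence per call.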
```diff
@@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
 
 CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   enc.in_concurrent_ = false;
+  if (!use_cuda_graphs()) {
+    return;
+  }
 
   // Use an empty graph node for synchronization
   CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
```
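Inside a concurrent region the encoder lets nodes run independently; on exit it inserts an empty node (tagged `'E'`) that depends on all of them, giving later work a single join point. A host-only sketch of that fan-in using the raw graph-construction API (a hand-built toy graph, not MLX's encoder; error checks omitted for brevity):

```cpp
#include <cuda_runtime.h>

int main() {
  float *a, *b;
  cudaMalloc(&a, 256 * sizeof(float));
  cudaMalloc(&b, 256 * sizeof(float));

  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);

  // Two independent memset nodes: no dependencies between them, so they may
  // run concurrently when the graph executes.
  cudaMemsetParams p = {};
  p.elementSize = sizeof(float); // must be 1, 2, or 4
  p.width = 256;
  p.height = 1;
  p.pitch = 256 * sizeof(float);
  p.value = 0;

  cudaGraphNode_t set_a, set_b;
  p.dst = a;
  cudaGraphAddMemsetNode(&set_a, graph, nullptr, 0, &p);
  p.dst = b;
  cudaGraphAddMemsetNode(&set_b, graph, nullptr, 0, &p);

  // The empty node does no work; it only records that both branches must
  // finish first. Anything added later depends on `join` alone.
  cudaGraphNode_t join;
  cudaGraphNode_t deps[] = {set_a, set_b};
  cudaGraphAddEmptyNode(&join, graph, deps, 2);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, 0);
  cudaGraphLaunch(exec, /*stream=*/0);
  cudaDeviceSynchronize();

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(a);
  cudaFree(b);
  return 0;
}
```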
```diff
@@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
 }
 
 void CommandEncoder::set_input_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
 
 void CommandEncoder::set_output_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
   active_outputs_.push_back(id);
```
```diff
@@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cudaLaunchKernel(
+        func, grid_dim, block_dim, params, smem_bytes, stream()));
+    return;
+  }
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
```
```diff
@@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cuLaunchKernel(
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr));
+    return;
+  }
+
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDimX = grid_dim.x;
```
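Both overloads now short-circuit when graphs are off: the runtime-API overload launches immediately via `cudaLaunchKernel` and the driver-API overload via `cuLaunchKernel`, instead of recording a kernel node. A compilable sketch of the two paths for the runtime API (hypothetical `use_graphs` flag and `scale` kernel, not MLX code; compile with nvcc):

```cpp
#include <cuda_runtime.h>

__global__ void scale(float* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

// Either run the kernel now on `stream`, or record it as a node of `graph`,
// mirroring the fallback added in this diff. Error checks omitted.
void launch_scale(bool use_graphs, cudaGraph_t graph, cudaStream_t stream,
                  float* x, float s, int n) {
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  void* params[] = {&x, &s, &n};

  if (!use_graphs) {
    // Eager path: identical arguments, immediate execution.
    cudaLaunchKernel((const void*)scale, grid, block, params,
                     /*sharedMem=*/0, stream);
    return;
  }

  // Graph path: pack the same launch into a kernel node. The argument
  // values are captured when the node is added, so locals are safe here.
  cudaKernelNodeParams kp = {};
  kp.func = (void*)scale;
  kp.gridDim = grid;
  kp.blockDim = block;
  kp.sharedMemBytes = 0;
  kp.kernelParams = params;
  cudaGraphNode_t node;
  cudaGraphAddKernelNode(&node, graph, nullptr, 0, &kp);
}
```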
```diff
@@ -256,6 +301,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 }
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  if (!use_cuda_graphs()) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(child);
+    device_.make_current();
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
+    return;
+  }
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(GraphNode{node, 'G'});
```
```diff
@@ -76,9 +76,6 @@ class CommandEncoder {
     uint32_t smem_bytes,
     void** params);
 
-  // Low-level graph helpers.
-  void add_kernel_node(const cudaKernelNodeParams& params);
-  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
   void add_graph_node(cudaGraph_t child);
 
   void add_temporary(const array& arr) {
```
```diff
@@ -101,6 +98,9 @@ class CommandEncoder {
   void synchronize();
 
  private:
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+
   struct GraphNode {
     cudaGraphNode_t node;
     // K = kernel
```
```diff
@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
       return path.string();
     }
     // Finally check the environment variable.
-    path = std::getenv("MLX_CCCL_DIR");
-    if (!path.empty() && std::filesystem::exists(path)) {
-      return path.string();
+    if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
+      path = env;
+      if (!path.empty() && std::filesystem::exists(path)) {
+        return path.string();
+      }
     }
     return std::string();
   }();
```
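The old code assigned `std::getenv`'s result directly; when `MLX_CCCL_DIR` is unset, `getenv` returns `nullptr`, and constructing or assigning a `std::filesystem::path` (or `std::string`) from a null `char*` is undefined behavior. The C++17 if-initializer form tests the raw pointer before converting. A self-contained sketch of the corrected pattern:

```cpp
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <string>

// Resolve a directory from an environment variable; return "" when the
// variable is unset, empty, or names a path that does not exist.
std::string dir_from_env(const char* var) {
  // std::getenv returns nullptr for an unset variable; converting nullptr
  // to std::filesystem::path is UB, so check the pointer first.
  if (const char* env = std::getenv(var); env) {
    std::filesystem::path path = env;
    if (!path.empty() && std::filesystem::exists(path)) {
      return path.string();
    }
  }
  return std::string();
}

int main() {
  std::cout << "MLX_CCCL_DIR -> '" << dir_from_env("MLX_CCCL_DIR") << "'\n";
}
```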
```diff
@@ -4,6 +4,7 @@
 
 #include "mlx/array.h"
 #include "mlx/backend/cuda/cuda.h"
+#include "mlx/backend/gpu/available.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
```
```diff
@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import mlx.core as mx
 import mlx_tests
```
```diff
@@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
         load_arr_mlx_npy = np.load(save_file_mlx)
         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
 
+        save_file = os.path.join(self.test_dir, f"mlx_path.npy")
+        save_arr = mx.ones((32,))
+        mx.save(Path(save_file), save_arr)
+
+        # Load array saved by mlx as mlx array
+        load_arr = mx.load(Path(save_file))
+        self.assertTrue(mx.array_equal(load_arr, save_arr))
+
     def test_save_and_load_safetensors(self):
         test_file = os.path.join(self.test_dir, "test.safetensors")
         with self.assertRaises(Exception):
             mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
 
-        mx.save_safetensors(
-            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
-        )
-        res = mx.load(test_file, return_metadata=True)
-        self.assertEqual(len(res), 2)
-        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
+        for obj in [str, Path]:
+            mx.save_safetensors(
+                obj(test_file),
+                {"test": mx.ones((2, 2))},
+                {"testing": "test", "format": "mlx"},
+            )
+            res = mx.load(obj(test_file), return_metadata=True)
+            self.assertEqual(len(res), 2)
+            self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
 
         for dt in self.dtypes + ["bfloat16"]:
             with self.subTest(dtype=dt):
```
```diff
@@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
             mx.array_equal(load_dict["test"], save_dict["test"])
         )
 
+        save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf")
+        save_dict = {"test": mx.ones(shape)}
+        mx.save_gguf(Path(save_file_mlx), save_dict)
+        load_dict = mx.load(Path(save_file_mlx))
+        self.assertTrue("test" in load_dict)
+        self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
+
     def test_load_f8_e4m3(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
```
```diff
@@ -10,7 +10,7 @@ using namespace mlx::core;
 
 TEST_CASE("test device placement") {
   auto device = default_device();
-  Device d = metal::is_available() ? Device::gpu : Device::cpu;
+  Device d = gpu::is_available() ? Device::gpu : Device::cpu;
   if (std::getenv("DEVICE") == nullptr) {
     CHECK_EQ(device, d);
   }
```
```diff
@@ -18,7 +18,7 @@ TEST_CASE("test device placement") {
   array x(1.0f);
   array y(1.0f);
   auto z = add(x, y, default_device());
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     z = add(x, y, Device::gpu);
     z = add(x, y, Device(Device::gpu, 0));
   } else {
```
```diff
@@ -16,7 +16,7 @@ TEST_CASE("test stream management") {
   CHECK_NE(s1, s2);
 
   // Check that default streams have the correct devices
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     auto s_gpu = default_stream(Device::gpu);
     CHECK_EQ(s_gpu.device, Device::gpu);
   } else {
```
```diff
@@ -28,7 +28,7 @@ TEST_CASE("test stream management") {
   s_cpu = new_stream(Device::cpu);
   CHECK_EQ(s_cpu.device, Device::cpu);
 
-  if (metal::is_available()) {
+  if (gpu::is_available()) {
     auto s_gpu = new_stream(Device::gpu);
     CHECK_EQ(s_gpu.device, Device::gpu);
   } else {
```