mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
3dcb286baf
...
a9bac3d9e5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9bac3d9e5 | ||
|
|
5458d43247 | ||
|
|
a4dba65220 |
@@ -232,6 +232,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
uv venv
|
uv venv
|
||||||
|
uv pip install cmake
|
||||||
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
uv pip install -e ".[dev]" -v
|
uv pip install -e ".[dev]" -v
|
||||||
- run:
|
- run:
|
||||||
@@ -240,6 +241,18 @@ jobs:
|
|||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||||
|
- run:
|
||||||
|
name: Build CPP only
|
||||||
|
command: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
cmake . -B build \
|
||||||
|
-DMLX_BUILD_CUDA=ON \
|
||||||
|
-DCMAKE_CUDA_COMPILER=`which nvcc` \
|
||||||
|
-DCMAKE_BUILD_TYPE=DEBUG
|
||||||
|
cmake --build build -j `nproc`
|
||||||
|
- run:
|
||||||
|
name: Run CPP tests
|
||||||
|
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||||
- run:
|
- run:
|
||||||
name: CCache report
|
name: CCache report
|
||||||
command: |
|
command: |
|
||||||
|
|||||||
@@ -140,6 +140,12 @@ elseif(MLX_BUILD_METAL)
|
|||||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# With newer clang/gcc versions following libs are implicitly linked, but when
|
||||||
|
# building on old distributions they need to be explicitly listed.
|
||||||
|
target_link_libraries(mlx PRIVATE dl pthread)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
# GGUF does not build with MSVC.
|
# GGUF does not build with MSVC.
|
||||||
|
|||||||
@@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
|
|
||||||
int cuda_graph_cache_size() {
|
int cuda_graph_cache_size() {
|
||||||
static int cache_size = []() {
|
static int cache_size = []() {
|
||||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
|
||||||
}();
|
}();
|
||||||
return cache_size;
|
return cache_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool use_cuda_graphs() {
|
||||||
|
static bool use_graphs = []() {
|
||||||
|
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
|
}();
|
||||||
|
return use_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
@@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
enc.device().make_current();
|
enc.device().make_current();
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
graph.end_capture(enc.stream());
|
graph.end_capture(enc.stream());
|
||||||
if (discard) {
|
if (discard) {
|
||||||
return;
|
return;
|
||||||
@@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
|||||||
|
|
||||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||||
enc.in_concurrent_ = false;
|
enc.in_concurrent_ = false;
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Use an empty graph node for synchronization
|
// Use an empty graph node for synchronization
|
||||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||||
@@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(const array& arr) {
|
void CommandEncoder::set_input_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_output_array(const array& arr) {
|
void CommandEncoder::set_output_array(const array& arr) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||||
active_deps_.push_back(id);
|
active_deps_.push_back(id);
|
||||||
active_outputs_.push_back(id);
|
active_outputs_.push_back(id);
|
||||||
@@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cudaLaunchKernel(
|
||||||
|
func, grid_dim, block_dim, params, smem_bytes, stream()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
@@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
|
|||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CHECK_CUDA_ERROR(cuLaunchKernel(
|
||||||
|
func,
|
||||||
|
grid_dim.x,
|
||||||
|
grid_dim.y,
|
||||||
|
grid_dim.z,
|
||||||
|
block_dim.x,
|
||||||
|
block_dim.y,
|
||||||
|
block_dim.z,
|
||||||
|
smem_bytes,
|
||||||
|
stream(),
|
||||||
|
params,
|
||||||
|
nullptr));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDimX = grid_dim.x;
|
kernel_params.gridDimX = grid_dim.x;
|
||||||
@@ -256,6 +301,12 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||||
|
if (!use_cuda_graphs()) {
|
||||||
|
CudaGraphExec graph_exec;
|
||||||
|
graph_exec.instantiate(child);
|
||||||
|
device_.make_current();
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
|
||||||
|
}
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||||
|
|||||||
@@ -76,9 +76,6 @@ class CommandEncoder {
|
|||||||
uint32_t smem_bytes,
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
// Low-level graph helpers.
|
|
||||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
|
||||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
|
||||||
void add_graph_node(cudaGraph_t child);
|
void add_graph_node(cudaGraph_t child);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
@@ -101,6 +98,9 @@ class CommandEncoder {
|
|||||||
void synchronize();
|
void synchronize();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||||
|
|
||||||
struct GraphNode {
|
struct GraphNode {
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
// K = kernel
|
// K = kernel
|
||||||
|
|||||||
@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
|
|||||||
return path.string();
|
return path.string();
|
||||||
}
|
}
|
||||||
// Finally check the environment variable.
|
// Finally check the environment variable.
|
||||||
path = std::getenv("MLX_CCCL_DIR");
|
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
|
||||||
if (!path.empty() && std::filesystem::exists(path)) {
|
path = env;
|
||||||
return path.string();
|
if (!path.empty() && std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::string();
|
return std::string();
|
||||||
}();
|
}();
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
@@ -62,17 +63,28 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
|
save_file = os.path.join(self.test_dir, f"mlx_path.npy")
|
||||||
|
save_arr = mx.ones((32,))
|
||||||
|
mx.save(Path(save_file), save_arr)
|
||||||
|
|
||||||
|
# Load array saved by mlx as mlx array
|
||||||
|
load_arr = mx.load(Path(save_file))
|
||||||
|
self.assertTrue(mx.array_equal(load_arr, save_arr))
|
||||||
|
|
||||||
def test_save_and_load_safetensors(self):
|
def test_save_and_load_safetensors(self):
|
||||||
test_file = os.path.join(self.test_dir, "test.safetensors")
|
test_file = os.path.join(self.test_dir, "test.safetensors")
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||||
|
|
||||||
mx.save_safetensors(
|
for obj in [str, Path]:
|
||||||
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
mx.save_safetensors(
|
||||||
)
|
obj(test_file),
|
||||||
res = mx.load(test_file, return_metadata=True)
|
{"test": mx.ones((2, 2))},
|
||||||
self.assertEqual(len(res), 2)
|
{"testing": "test", "format": "mlx"},
|
||||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
)
|
||||||
|
res = mx.load(obj(test_file), return_metadata=True)
|
||||||
|
self.assertEqual(len(res), 2)
|
||||||
|
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||||
|
|
||||||
for dt in self.dtypes + ["bfloat16"]:
|
for dt in self.dtypes + ["bfloat16"]:
|
||||||
with self.subTest(dtype=dt):
|
with self.subTest(dtype=dt):
|
||||||
@@ -128,6 +140,13 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
save_file_mlx = os.path.join(self.test_dir, f"mlx_path_test_fs.gguf")
|
||||||
|
save_dict = {"test": mx.ones(shape)}
|
||||||
|
mx.save_gguf(Path(save_file_mlx), save_dict)
|
||||||
|
load_dict = mx.load(Path(save_file_mlx))
|
||||||
|
self.assertTrue("test" in load_dict)
|
||||||
|
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||||
|
|
||||||
def test_load_f8_e4m3(self):
|
def test_load_f8_e4m3(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ using namespace mlx::core;
|
|||||||
|
|
||||||
TEST_CASE("test device placement") {
|
TEST_CASE("test device placement") {
|
||||||
auto device = default_device();
|
auto device = default_device();
|
||||||
Device d = metal::is_available() ? Device::gpu : Device::cpu;
|
Device d = gpu::is_available() ? Device::gpu : Device::cpu;
|
||||||
if (std::getenv("DEVICE") == nullptr) {
|
if (std::getenv("DEVICE") == nullptr) {
|
||||||
CHECK_EQ(device, d);
|
CHECK_EQ(device, d);
|
||||||
}
|
}
|
||||||
@@ -18,7 +18,7 @@ TEST_CASE("test device placement") {
|
|||||||
array x(1.0f);
|
array x(1.0f);
|
||||||
array y(1.0f);
|
array y(1.0f);
|
||||||
auto z = add(x, y, default_device());
|
auto z = add(x, y, default_device());
|
||||||
if (metal::is_available()) {
|
if (gpu::is_available()) {
|
||||||
z = add(x, y, Device::gpu);
|
z = add(x, y, Device::gpu);
|
||||||
z = add(x, y, Device(Device::gpu, 0));
|
z = add(x, y, Device(Device::gpu, 0));
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ TEST_CASE("test stream management") {
|
|||||||
CHECK_NE(s1, s2);
|
CHECK_NE(s1, s2);
|
||||||
|
|
||||||
// Check that default streams have the correct devices
|
// Check that default streams have the correct devices
|
||||||
if (metal::is_available()) {
|
if (gpu::is_available()) {
|
||||||
auto s_gpu = default_stream(Device::gpu);
|
auto s_gpu = default_stream(Device::gpu);
|
||||||
CHECK_EQ(s_gpu.device, Device::gpu);
|
CHECK_EQ(s_gpu.device, Device::gpu);
|
||||||
} else {
|
} else {
|
||||||
@@ -28,7 +28,7 @@ TEST_CASE("test stream management") {
|
|||||||
s_cpu = new_stream(Device::cpu);
|
s_cpu = new_stream(Device::cpu);
|
||||||
CHECK_EQ(s_cpu.device, Device::cpu);
|
CHECK_EQ(s_cpu.device, Device::cpu);
|
||||||
|
|
||||||
if (metal::is_available()) {
|
if (gpu::is_available()) {
|
||||||
auto s_gpu = new_stream(Device::gpu);
|
auto s_gpu = new_stream(Device::gpu);
|
||||||
CHECK_EQ(s_gpu.device, Device::gpu);
|
CHECK_EQ(s_gpu.device, Device::gpu);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user