mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-31 15:21:19 +08:00
Compare commits: 011e59fb73...64af1f8920 (4 commits)

- 64af1f8920
- b3d7b85376
- cad5c0241c
- d2e0b0465c
@@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
 
 # --------------------- Processor tests -------------------------
 message(

@@ -234,12 +235,16 @@ target_include_directories(
 # Do not add mlx_EXPORTS define for shared library.
 set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
 
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
 FetchContent_Declare(
   fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
   GIT_TAG 10.2.1
   EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 
 if(MLX_BUILD_PYTHON_BINDINGS)

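
Taken together, these two CMakeLists.txt hunks make the vendored fmt optional: configuring with `-DUSE_SYSTEM_FMT=ON` swaps the FetchContent checkout of fmt 10.2.1 for `find_package(fmt REQUIRED)`. Either way the `fmt::fmt-header-only` target must resolve; a minimal consumer-side smoke test (this file is illustrative, not part of the repo):

```cpp
// fmt_check.cpp -- compiles against whichever fmt the option selected.
#include <fmt/format.h>

int main() {
  // FMT_VERSION is defined by fmt itself, e.g. 100201 for 10.2.1.
  fmt::print("linked against fmt {}\n", FMT_VERSION);
  return 0;
}
```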
@@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
 
   cudaFree(buf);
 }

|
@ -63,25 +63,30 @@ void copy_general(
|
|||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
size_t data_size = 1;
|
||||||
|
for (auto& s : shape)
|
||||||
|
data_size *= s;
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
data_size,
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out));
|
const_param<NDIM>(strides_out));
|
||||||
});
|
});
|
||||||
} else { // ndim >= 4
|
} else { // ndim >= 4
|
||||||
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, data_size, shape, out.strides(), large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
data_size,
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
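
The launch arguments and the element count passed to the kernel now use `data_size`, the product of the copy shape, instead of `out.size()`; the two can differ when the copy targets a strided or offset view of `out`. The computation in isolation (a stand-alone sketch, not the MLX helper):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of logical elements in a copy: the product of the shape's dims.
// For shape {2, 3, 4} this yields 24 regardless of the output's total size.
size_t copy_data_size(const std::vector<int32_t>& shape) {
  size_t data_size = 1;
  for (auto s : shape)
    data_size *= static_cast<size_t>(s);
  return data_size;
}
```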
@@ -6,6 +6,7 @@
 
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>
 
 namespace mlx::core {
 
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }
 
+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  worker_.commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);

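
`CommandEncoder::synchronize()` blocks the caller until both the CUDA stream and the worker's completion handlers have drained: it installs a sentinel handler that fulfills a promise, flushes the batch, and waits on the matching future. A minimal sketch of that handshake, with a plain thread standing in for the stream/worker pair (names are illustrative, not MLX's API):

```cpp
#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <thread>

int main() {
  std::queue<std::function<void()>> handlers;
  handlers.push([] { /* previously enqueued completion work */ });

  // Sentinel: fulfills the promise only after everything before it has run.
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  handlers.push([p = std::move(p)] { p->set_value(); });

  // "Commit": hand the batch to a worker, as worker_.commit() would.
  std::thread worker([&handlers] {
    while (!handlers.empty()) {
      handlers.front()();
      handlers.pop();
    }
  });

  f.wait();  // caller blocks until the sentinel handler has fired
  worker.join();
  return 0;
}
```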
@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }
 
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
   Device& device_;
   DeviceStream& stream_;

@@ -62,7 +62,7 @@ void finalize(Stream s) {
 
 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }
 
 } // namespace mlx::core::gpu

@@ -37,7 +37,8 @@ void check_cu_error(const char* name, CUresult err) {
 }
 
 // Return the location of the CUDA toolkit.
-const char* cuda_home() {
+const std::string& cuda_home() {
+  static std::string home = []() -> std::string {
   const char* home = std::getenv("CUDA_HOME");
   if (home) {
     return home;

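
`cuda_home()` now computes its result once, inside a function-local static initialized by an immediately-invoked lambda, and hands out a stable `const std::string&` instead of a raw `const char*`. The idiom stand-alone (shortened; per the error message in the next hunk, the real function also consults CUDA_PATH and platform defaults before throwing):

```cpp
#include <cstdlib>
#include <stdexcept>
#include <string>

const std::string& example_cuda_home() {
  // Evaluated exactly once, thread-safely; later calls reuse the value.
  static std::string home = []() -> std::string {
    if (const char* h = std::getenv("CUDA_HOME")) {
      return h;
    }
    if (const char* h = std::getenv("CUDA_PATH")) {
      return h;
    }
    throw std::runtime_error(
        "Environment variable CUDA_HOME or CUDA_PATH is not set.");
  }();
  return home;
}
```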
@@ -54,19 +55,28 @@ const char* cuda_home() {
 #endif
     throw std::runtime_error(
         "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+  }();
+  return home;
 }
 
 // Get the cache directory for storing compiled results.
-bool get_ptx_cache_dir(std::filesystem::path* result) {
-  auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
-  if (!std::filesystem::is_directory(path)) {
+const std::filesystem::path& ptx_cache_dir() {
+  static std::filesystem::path cache = []() -> std::filesystem::path {
+    std::filesystem::path cache;
+    if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
+      cache = c;
+    } else {
+      cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
+    }
+    if (!std::filesystem::exists(cache)) {
       std::error_code error;
-    if (!std::filesystem::create_directories(path, error)) {
-      return false;
+      if (!std::filesystem::create_directories(cache, error)) {
+        return std::filesystem::path();
       }
     }
-  *result = path;
-  return true;
+    return cache;
+  }();
+  return cache;
 }
 
 // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.

@@ -75,6 +85,10 @@ bool read_cached_ptx(
     const std::string& module_name,
     std::vector<char>* ptx,
     std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+  if (cache_dir.empty()) {
+    return false;
+  }
+
   auto ptx_path = cache_dir / (module_name + ".ptx");
   std::error_code error;
   auto ptx_size = std::filesystem::file_size(ptx_path, error);

@@ -105,6 +119,10 @@ void write_cached_ptx(
     const std::string& module_name,
     const std::vector<char>& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+  if (cache_dir.empty()) {
+    return;
+  }
+
   std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
   if (!ptx.empty()) {
     ptx_file.write(&ptx.front(), ptx.size());

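
With `ptx_cache_dir()` returning an empty path when no usable directory exists, both cache helpers now treat the empty path as a "caching disabled" sentinel and degrade to no-ops, so JIT compilation keeps working without persistence. A compact mirror of the write-side guard (illustrative, not MLX code):

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>

// Skip silently when the cache directory is the empty "disabled" sentinel.
void write_blob(
    const std::filesystem::path& cache_dir,
    const std::string& name,
    const std::vector<char>& bytes) {
  if (cache_dir.empty()) {
    return;  // no cache: results are recomputed instead of persisted
  }
  std::ofstream file(cache_dir / (name + ".bin"), std::ios::binary);
  if (!bytes.empty()) {
    file.write(bytes.data(), bytes.size());
  }
}
```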
@@ -184,11 +202,9 @@ JitModule::JitModule(
     const std::string& module_name,
     const KernelBuilder& builder) {
   // Check cache.
-  std::filesystem::path cache_dir;
   std::vector<char> ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!get_ptx_cache_dir(&cache_dir) ||
-      !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
     // Create program.
     auto [source_code, kernel_names] = builder();
     nvrtcProgram prog;

@@ -246,7 +262,7 @@ JitModule::JitModule(
   } else {
     CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
   }
-  write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
+  write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
 }
 
 // Load module.

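
Net effect of the two JitModule hunks: the cache directory is no longer threaded through the constructor as an out-parameter; `read_cached_ptx()` and `write_cached_ptx()` take `ptx_cache_dir()` directly and no-op when caching is unavailable. The resulting control flow in miniature, with trivial stand-ins (none of these helpers are MLX's):

```cpp
#include <string>
#include <vector>

// Stand-ins so the sketch is self-contained.
static bool read_cache(const std::string&, std::vector<char>*) {
  return false;  // pretend cache miss
}
static std::vector<char> compile(const std::string& name) {
  return std::vector<char>(name.begin(), name.end());
}
static void write_cache(const std::string&, const std::vector<char>&) {}

std::vector<char> load_or_build(const std::string& module_name) {
  std::vector<char> ptx;
  if (!read_cache(module_name, &ptx)) {
    ptx = compile(module_name);     // cache miss: build from source
    write_cache(module_name, ptx);  // no-op when caching is disabled
  }
  return ptx;
}
```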
@@ -80,7 +80,9 @@ void Worker::thread_fn() {
       }
       worker_tasks_.erase(worker_tasks_.begin(), end);
     }
-    for (auto& task : tasks) {
+    // Make sure tasks are cleared before the next wait
+    for (int i = 0; i < tasks.size(); ++i) {
+      auto task = std::move(tasks[i]);
       task();
     }
     worker_event_.wait(batch + 1);

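
The added comment states the intent: iterating by reference leaves each callable, and whatever it captured (for instance a promise-fulfilling completion handler), alive inside `tasks` until after the next `worker_event_.wait()`. Moving each task into a local destroys its captured state as soon as it has run. The same pattern in isolation (illustrative, not MLX code):

```cpp
#include <functional>
#include <vector>

// Run every task, releasing each one's captured state immediately after it
// executes rather than when the whole container goes out of scope.
void run_and_release(std::vector<std::function<void()>>& tasks) {
  for (size_t i = 0; i < tasks.size(); ++i) {
    auto task = std::move(tasks[i]);  // take ownership of captured state
    task();                           // destroyed at the end of this scope
  }
  tasks.clear();
}
```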
@@ -6,7 +6,6 @@ cuda_skip = {
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestMemory.test_memory_info",
     "TestLayers.test_group_norm",
     "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",