From 0ab8e099e86c6d550d339b945d04a2e5f9f8050b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Oct 2024 16:17:03 -0700
Subject: [PATCH] Fix cpu segfault (#1488)

* fix cpu segfault

* nit in tests
---
 mlx/backend/common/compiled_cpu.cpp | 58 +++++++++++++++--------------
 mlx/compile_impl.h                  | 18 ++++++++-
 mlx/transforms_impl.h               | 15 --------
 python/src/ops.cpp                  |  5 +--
 python/src/transforms.cpp           |  1 +
 python/tests/test_conv_transpose.py |  6 +--
 6 files changed, 52 insertions(+), 51 deletions(-)

diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp
index 2c0df6073..5eb904b45 100644
--- a/mlx/backend/common/compiled_cpu.cpp
+++ b/mlx/backend/common/compiled_cpu.cpp
@@ -14,6 +14,30 @@ namespace mlx::core {
 
+struct CompilerCache {
+  struct DLib {
+    DLib(const std::string& libname) {
+      lib = dlopen(libname.c_str(), RTLD_NOW);
+      if (!lib) {
+        std::ostringstream msg;
+        msg << "Could not load C++ shared library " << dlerror();
+        throw std::runtime_error(msg.str());
+      }
+    }
+
+    ~DLib() {
+      dlclose(lib);
+    }
+    void* lib;
+  };
+  // Statics to cache compiled libraries and functions
+  std::list<DLib> libs;
+  std::unordered_map<std::string, void*> kernels;
+  std::shared_mutex mtx;
+};
+
+static CompilerCache cache{};
+
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
 namespace detail {
@@ -30,35 +54,15 @@ std::string get_temp_file(const std::string& name) {
 void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
-  struct DLib {
-    DLib(const std::string& libname) {
-      lib = dlopen(libname.c_str(), RTLD_NOW);
-      if (!lib) {
-        std::ostringstream msg;
-        msg << "Could not load C++ shared library " << dlerror();
-        throw std::runtime_error(msg.str());
-      }
-    }
-
-    ~DLib() {
-      dlclose(lib);
-    }
-    void* lib;
-  };
-  // Statics to cache compiled libraries and functions
-  static std::list<DLib> libs;
-  static std::unordered_map<std::string, void*> kernels;
-  static std::shared_mutex compile_mtx;
-
   {
-    std::shared_lock lock(compile_mtx);
-    if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+    std::shared_lock lock(cache.mtx);
+    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
       return it->second;
     }
   }
-  std::unique_lock lock(compile_mtx);
-  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+  std::unique_lock lock(cache.mtx);
+  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
@@ -112,10 +116,10 @@ void* compile(
   }
 
   // load library
-  libs.emplace_back(shared_lib_path);
+  cache.libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
         << kernel_name << " with error: " << dlerror();
     throw std::runtime_error(msg.str());
   }
-  kernels.insert({kernel_name, fun});
+  cache.kernels.insert({kernel_name, fun});
   return fun;
 }
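
Note on the hunk above: compile() previously owned its cache as three function-local statics (libs, kernels, compile_mtx). Function-local statics are torn down at exit in an order that depends on first use, so the DLib destructors could dlclose the compiled libraries while other static teardown still referenced them, which is the likely source of the segfault this patch fixes. Hoisting everything into one file-scope CompilerCache gives the dlopen handles, the kernel map, and the mutex a single shared lifetime. The lookup keeps its double-checked locking shape; the following self-contained sketch shows that pattern in isolation (Cache, lookup_or_build, and build are illustrative names, not part of the patch):

    // Double-checked locking over a shared_mutex-protected cache,
    // mirroring the lookup in compile() above. Requires C++17.
    #include <shared_mutex>
    #include <string>
    #include <unordered_map>

    struct Cache {
      std::unordered_map<std::string, void*> kernels;
      std::shared_mutex mtx;
    };

    static Cache cache{};

    void* lookup_or_build(
        const std::string& name,
        void* (*build)(const std::string&)) {
      {
        // Fast path: many threads may probe the cache concurrently.
        std::shared_lock lock(cache.mtx);
        if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
          return it->second;
        }
      }
      // Slow path: take the exclusive lock, then re-check, since another
      // thread may have built the same entry while we waited.
      std::unique_lock lock(cache.mtx);
      if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
        return it->second;
      }
      void* fun = build(name);
      cache.kernels.insert({name, fun});
      return fun;
    }

The re-check under the exclusive lock matters: two threads can both miss under the shared lock, and only the first to acquire the unique lock should compile and insert.
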
diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h
index 77ccedf38..0f18e1dac 100644
--- a/mlx/compile_impl.h
+++ b/mlx/compile_impl.h
@@ -6,6 +6,20 @@
 
 namespace mlx::core::detail {
 
-bool compile_available_for_device(const Device& device);
+// This is not part of the general C++ API as calling with a bad id is a bad
+// idea.
+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    std::uintptr_t fun_id,
+    bool shapeless = false,
+    std::vector<array> constants = {});
 
-}
+// Erase cached compile functions
+void compile_erase(std::uintptr_t fun_id);
+
+// Clear the compiler cache causing a recompilation of all compiled functions
+// when called again.
+void compile_clear_cache();
+
+bool compile_available_for_device(const Device& device);
+} // namespace mlx::core::detail
diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h
index 6f67305e8..c83f6795d 100644
--- a/mlx/transforms_impl.h
+++ b/mlx/transforms_impl.h
@@ -16,21 +16,6 @@ std::vector<array> vmap_replace(
     const std::vector<int>& in_axes,
     const std::vector<int>& out_axes);
 
-// This is not part of the general C++ API as calling with a bad id is a bad
-// idea.
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    std::uintptr_t fun_id,
-    bool shapeless = false,
-    std::vector<array> constants = {});
-
-// Erase cached compile functions
-void compile_erase(std::uintptr_t fun_id);
-
-// Clear the compiler cache causing a recompilation of all compiled functions
-// when called again.
-void compile_clear_cache();
-
 // Create an InTracing object during tracing operations to signify to the rest
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index b0b308a91..4ffa21dd9 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
       start (float or int, optional): Starting value which defaults to ``0``.
       stop (float or int): Stopping value.
       step (float or int, optional): Increment which defaults to ``1``.
-      dtype (Dtype, optional): Specifies the data type of the output.
-        If unspecified will default to ``float32`` if any of ``start``,
-        ``stop``, or ``step`` are ``float``. Otherwise will default to
-        ``int32``.
+      dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
 
     Returns:
       array: The range of values.
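
Note on the header move: the id-keyed compile(), compile_erase(), and compile_clear_cache() declarations leave transforms_impl.h for compile_impl.h so that python/src/transforms.cpp (next hunk) can include them directly. A minimal sketch of how a binding layer might pair these calls, assuming nothing beyond the declarations above (CompiledWrapper is hypothetical, not anything in this patch):

    // Hypothetical wrapper pairing detail::compile with detail::compile_erase;
    // only those two calls come from mlx/compile_impl.h.
    #include <cstdint>
    #include <functional>
    #include <vector>

    #include "mlx/array.h"
    #include "mlx/compile_impl.h"

    namespace mx = mlx::core;

    struct CompiledWrapper {
      std::function<std::vector<mx::array>(const std::vector<mx::array>&)> fun;

      std::vector<mx::array> operator()(const std::vector<mx::array>& args) {
        // Key the cache on this wrapper's address; as the header comment
        // warns, reusing an id for a different function is unsafe.
        auto id = reinterpret_cast<std::uintptr_t>(this);
        return mx::detail::compile(fun, id)(args);
      }

      ~CompiledWrapper() {
        // Drop any cached traces tied to this wrapper's id.
        mx::detail::compile_erase(reinterpret_cast<std::uintptr_t>(this));
      }
    };

Using the wrapper's own address as fun_id mirrors the header's warning: the id must uniquely identify the traced function for as long as cache entries for it exist, and compile_erase must run before the id can safely be reused.
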
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 82759cfcc..29564a707 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -13,6 +13,7 @@ #include "mlx/array.h" #include "mlx/compile.h" +#include "mlx/compile_impl.h" #include "mlx/graph_utils.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 0efff048d..7b458914d 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): # use torch to compute ct out_pt.retain_grad() - (out_pt - torch.randn_like(out_pt)).abs().sum().backward() + out_pt.sum().backward() pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy() pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy() @@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): # use torch to compute ct out_pt.retain_grad() - (out_pt - torch.randn_like(out_pt)).abs().sum().backward() + out_pt.sum().backward() pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy() pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy() @@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase): # use torch to compute ct out_pt.retain_grad() - (out_pt - torch.randn_like(out_pt)).abs().sum().backward() + out_pt.sum().backward() pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy() pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()
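
Note on the test change: the old line backpropagated through (out_pt - torch.randn_like(out_pt)).abs().sum(), so the reference gradients depended on the sign of out minus a fresh random draw and varied run to run. out_pt.sum().backward() propagates an all-ones cotangent, making the PyTorch reference deterministic. The same check can be phrased as an explicit VJP; a sketch in MLX's C++ API follows (the lambda is a stand-in for the conv under test, and this assumes the vjp(fun, primals, cotangents) form from mlx/transforms.h):

    // Sum-then-backward is a VJP with an all-ones cotangent.
    #include <iostream>
    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto fun = [](const std::vector<mx::array>& xs) {
        return std::vector<mx::array>{mx::square(xs[0])};
      };
      auto x = mx::array({1.0f, 2.0f, 3.0f});
      // An all-ones cotangent matches what out.sum().backward() propagates.
      auto cotan = mx::ones_like(fun({x})[0]);
      auto [outs, grads] = mx::vjp(fun, {x}, {cotan});
      std::cout << grads[0] << std::endl; // 2 * x, deterministic every run
      return 0;
    }
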