Fix cpu segfault (#1488)

* fix cpu segfault

* nit in tests
Awni Hannun 2024-10-14 16:17:03 -07:00 committed by GitHub
parent 020f048cd0
commit 0ab8e099e8
6 changed files with 52 additions and 51 deletions
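
The heart of the fix is in the first file below: three function-local statics inside compile() (libs, kernels, compile_mtx) become members of a single file-scope static CompilerCache cache{}. The commit message does not spell out the crash mechanism, so the following is a plausible reading rather than a statement of the actual bug: a function-local static is constructed on first use and destroyed at exit in reverse order of construction, so a lazily built cache that dlclose()s its libraries in its destructor can be torn down while later-destroyed statics still hold pointers into the unloaded code. A minimal standalone sketch of that hazard class (not MLX code; all names are illustrative):

#include <cstdio>

struct Cache {
  // Stand-in for the compiler cache; imagine dlclose() happening here.
  ~Cache() { std::puts("cache destroyed"); }
};

Cache& cache() {
  static Cache c; // constructed on first use, so destroyed early at exit
  return c;
}

struct Worker {
  ~Worker() {
    // Runs after the cache is gone, because `w` finished construction first.
    std::puts("worker destroyed (might still call into the cache)");
  }
};
Worker w; // constructed during static initialization, destroyed last

int main() {
  cache(); // first use: the cache is constructed after `w`
} // exit prints "cache destroyed" before "worker destroyed (...)"

Moving the cache to a file-scope static makes its construction eager and its destruction correspondingly late, and grouping the pieces into one struct keeps them from being torn down independently.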


@@ -14,6 +14,30 @@
 namespace mlx::core {
+struct CompilerCache {
+  struct DLib {
+    DLib(const std::string& libname) {
+      lib = dlopen(libname.c_str(), RTLD_NOW);
+      if (!lib) {
+        std::ostringstream msg;
+        msg << "Could not load C++ shared library " << dlerror();
+        throw std::runtime_error(msg.str());
+      }
+    }
+    ~DLib() {
+      dlclose(lib);
+    }
+    void* lib;
+  };
+  // Statics to cache compiled libraries and functions
+  std::list<DLib> libs;
+  std::unordered_map<std::string, void*> kernels;
+  std::shared_mutex mtx;
+};
+static CompilerCache cache{};
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
 namespace detail {
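
A side note on the new layout (not stated in the diff, but a property it can rely on): struct members are destroyed in reverse declaration order, so when the single cache object is torn down, the kernels map of raw dlsym pointers is discarded before libs runs the DLib destructors that dlclose the libraries. A tiny standalone sketch of that guarantee, with illustrative names:

#include <cstdio>

struct Noisy {
  const char* name;
  explicit Noisy(const char* n) : name(n) {}
  ~Noisy() { std::printf("%s destroyed\n", name); }
};

// Members are destroyed in reverse declaration order, mirroring how
// CompilerCache's kernels map goes away before its libs list.
struct Cache {
  Noisy libs{"libs"};       // declared first  -> destroyed last
  Noisy kernels{"kernels"}; // declared second -> destroyed first
};

int main() {
  Cache c;
  // On scope exit this prints "kernels destroyed", then "libs destroyed".
}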
@@ -30,35 +54,15 @@ std::string get_temp_file(const std::string& name) {
 void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
-  struct DLib {
-    DLib(const std::string& libname) {
-      lib = dlopen(libname.c_str(), RTLD_NOW);
-      if (!lib) {
-        std::ostringstream msg;
-        msg << "Could not load C++ shared library " << dlerror();
-        throw std::runtime_error(msg.str());
-      }
-    }
-    ~DLib() {
-      dlclose(lib);
-    }
-    void* lib;
-  };
-  // Statics to cache compiled libraries and functions
-  static std::list<DLib> libs;
-  static std::unordered_map<std::string, void*> kernels;
-  static std::shared_mutex compile_mtx;
   {
-    std::shared_lock lock(compile_mtx);
-    if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+    std::shared_lock lock(cache.mtx);
+    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
       return it->second;
     }
   }
-  std::unique_lock lock(compile_mtx);
-  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+  std::unique_lock lock(cache.mtx);
+  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
@@ -112,10 +116,10 @@ void* compile(
   }
   // load library
-  libs.emplace_back(shared_lib_path);
+  cache.libs.emplace_back(shared_lib_path);
   // Load function
-  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -123,7 +127,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  kernels.insert({kernel_name, fun});
+  cache.kernels.insert({kernel_name, fun});
   return fun;
 }
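
Worth noting while reading the hunks above: the commit keeps the existing double-checked lookup, only retargeting it at cache.mtx. A std::shared_lock serves the common hit path so concurrent readers never serialize, and a miss upgrades to a std::unique_lock with a second find(), since another thread may have compiled and inserted the same kernel between the two locks. A self-contained sketch of the pattern (C++17; names are illustrative, not MLX's):

#include <shared_mutex>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, void*> table; // kernel name -> function ptr
std::shared_mutex mtx;

void* lookup_or_insert(const std::string& key, void* (*build)(const std::string&)) {
  {
    // Fast path: a shared lock lets many readers probe the cache at once.
    std::shared_lock lock(mtx);
    if (auto it = table.find(key); it != table.end()) {
      return it->second;
    }
  }
  // Slow path: take the exclusive lock, then re-check, because another
  // thread may have inserted between releasing and reacquiring.
  std::unique_lock lock(mtx);
  if (auto it = table.find(key); it != table.end()) {
    return it->second;
  }
  void* value = build(key); // expensive step: compile + dlopen + dlsym
  table.insert({key, value});
  return value;
}

The build step runs under the exclusive lock, which serializes compilations but guarantees each kernel is built exactly once.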


@@ -6,6 +6,20 @@
 namespace mlx::core::detail {
-bool compile_available_for_device(const Device& device);
+// This is not part of the general C++ API as calling with a bad id is a bad
+// idea.
+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    std::uintptr_t fun_id,
+    bool shapeless = false,
+    std::vector<uint64_t> constants = {});
+// Erase cached compile functions
+void compile_erase(std::uintptr_t fun_id);
+// Clear the compiler cache causing a recompilation of all compiled functions
+// when called again.
+void compile_clear_cache();
+bool compile_available_for_device(const Device& device);
 } // namespace mlx::core::detail


@@ -16,21 +16,6 @@ std::vector<array> vmap_replace(
     const std::vector<int>& in_axes,
     const std::vector<int>& out_axes);
-// This is not part of the general C++ API as calling with a bad id is a bad
-// idea.
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    std::uintptr_t fun_id,
-    bool shapeless = false,
-    std::vector<uint64_t> constants = {});
-// Erase cached compile functions
-void compile_erase(std::uintptr_t fun_id);
-// Clear the compiler cache causing a recompilation of all compiled functions
-// when called again.
-void compile_clear_cache();
 // Create an InTracing object during tracing operations to signify to the rest
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.


@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
             start (float or int, optional): Starting value which defaults to ``0``.
             stop (float or int): Stopping value.
             step (float or int, optional): Increment which defaults to ``1``.
-            dtype (Dtype, optional): Specifies the data type of the output.
-              If unspecified will default to ``float32`` if any of ``start``,
-              ``stop``, or ``step`` are ``float``. Otherwise will default to
-              ``int32``.
+            dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
         Returns:
             array: The range of values.


@@ -13,6 +13,7 @@
 #include "mlx/array.h"
 #include "mlx/compile.h"
+#include "mlx/compile_impl.h"
 #include "mlx/graph_utils.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"


@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                         # use torch to compute ct
                         out_pt.retain_grad()
-                        (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+                        out_pt.sum().backward()
                         pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                         pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                         # use torch to compute ct
                         out_pt.retain_grad()
-                        (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+                        out_pt.sum().backward()
                         pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                         pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                         # use torch to compute ct
                         out_pt.retain_grad()
-                        (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+                        out_pt.sum().backward()
                         pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                         pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()