mirror of https://github.com/ml-explore/mlx.git
parent 020f048cd0
commit 0ab8e099e8
@@ -14,6 +14,30 @@
 namespace mlx::core {
 
+struct CompilerCache {
+  struct DLib {
+    DLib(const std::string& libname) {
+      lib = dlopen(libname.c_str(), RTLD_NOW);
+      if (!lib) {
+        std::ostringstream msg;
+        msg << "Could not load C++ shared library " << dlerror();
+        throw std::runtime_error(msg.str());
+      }
+    }
+
+    ~DLib() {
+      dlclose(lib);
+    }
+    void* lib;
+  };
+  // Statics to cache compiled libraries and functions
+  std::list<DLib> libs;
+  std::unordered_map<std::string, void*> kernels;
+  std::shared_mutex mtx;
+};
+
+static CompilerCache cache{};
+
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
 namespace detail {
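Hoisting the former function-local statics into a single file-scope CompilerCache means other code in this translation unit can now reach the same libraries, kernels, and mutex. A hedged sketch of the kind of helper this enables; the function name clear_cpu_compile_cache is illustrative only and is not something this commit adds:

    // Illustrative only: clears the file-scope cache under the writer lock.
    // The function name is hypothetical and not part of this diff.
    void clear_cpu_compile_cache() {
      std::unique_lock lock(cache.mtx);
      cache.kernels.clear();
      cache.libs.clear(); // ~DLib() dlclose()s each loaded library
    }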
@@ -30,35 +54,15 @@ std::string get_temp_file(const std::string& name) {
 void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
-  struct DLib {
-    DLib(const std::string& libname) {
-      lib = dlopen(libname.c_str(), RTLD_NOW);
-      if (!lib) {
-        std::ostringstream msg;
-        msg << "Could not load C++ shared library " << dlerror();
-        throw std::runtime_error(msg.str());
-      }
-    }
-
-    ~DLib() {
-      dlclose(lib);
-    }
-    void* lib;
-  };
-  // Statics to cache compiled libraries and functions
-  static std::list<DLib> libs;
-  static std::unordered_map<std::string, void*> kernels;
-  static std::shared_mutex compile_mtx;
-
   {
-    std::shared_lock lock(compile_mtx);
-    if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+    std::shared_lock lock(cache.mtx);
+    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
       return it->second;
     }
   }
 
-  std::unique_lock lock(compile_mtx);
-  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+  std::unique_lock lock(cache.mtx);
+  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
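The lookup above is a read-mostly, double-checked pattern: take a shared_lock for the fast path, and only fall back to a unique_lock, re-checking the map, before doing the expensive compile. A self-contained sketch of the same pattern; the cache type, key, and build step here are simplified placeholders, not MLX code:

    #include <mutex>
    #include <shared_mutex>
    #include <string>
    #include <unordered_map>

    // Simplified stand-in for CompilerCache: maps a key to an already-built value.
    static std::unordered_map<std::string, std::string> table;
    static std::shared_mutex table_mtx;

    std::string get_or_build(const std::string& key) {
      {
        // Fast path: many readers can hold the shared lock concurrently.
        std::shared_lock lock(table_mtx);
        if (auto it = table.find(key); it != table.end()) {
          return it->second;
        }
      }
      // Slow path: exclusive lock, then re-check in case another thread
      // built the entry between releasing the shared lock and acquiring this one.
      std::unique_lock lock(table_mtx);
      if (auto it = table.find(key); it != table.end()) {
        return it->second;
      }
      std::string value = "built:" + key; // stand-in for the expensive compile step
      table.insert({key, value});
      return value;
    }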
@@ -112,10 +116,10 @@ void* compile(
   }
 
   // load library
-  libs.emplace_back(shared_lib_path);
+  cache.libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -123,7 +127,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  kernels.insert({kernel_name, fun});
+  cache.kernels.insert({kernel_name, fun});
   return fun;
 }
 
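dlsym hands back an untyped void*, so whoever calls compile() has to cast the cached pointer to the kernel's actual signature before invoking it. A hedged sketch of that step; the void(void**) signature and the run_kernel helper are assumptions for illustration, the real calling convention lives elsewhere in this file:

    // Illustrative only: the kernel signature shown here is an assumption.
    using CompiledKernel = void (*)(void**);

    void run_kernel(
        const std::string& kernel_name,
        const std::function<std::string(void)>& source_builder,
        std::vector<void*>& args) {
      // compile() returns the raw dlsym pointer, cached across calls.
      void* fun = compile(kernel_name, source_builder);
      auto kernel = reinterpret_cast<CompiledKernel>(fun);
      kernel(args.data());
    }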
@@ -6,6 +6,20 @@
 namespace mlx::core::detail {
 
-bool compile_available_for_device(const Device& device);
+// This is not part of the general C++ API as calling with a bad id is a bad
+// idea.
+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    std::uintptr_t fun_id,
+    bool shapeless = false,
+    std::vector<uint64_t> constants = {});
 
-}
+// Erase cached compile functions
+void compile_erase(std::uintptr_t fun_id);
+
+// Clear the compiler cache causing a recompilation of all compiled functions
+// when called again.
+void compile_clear_cache();
+
+bool compile_available_for_device(const Device& device);
+
+} // namespace mlx::core::detail
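These declarations are internal, hence the "bad id" warning: fun_id keys the compiled function in a global cache, so the caller must supply a stable identifier and erase it once the function goes away. A hedged usage sketch built only from the declarations above; the call_compiled wrapper is illustrative, not code from this commit:

    // Illustrative sketch of how the internal API fits together.
    using mlx::core::array;
    using Fn = std::function<std::vector<array>(const std::vector<array>&)>;

    std::vector<array> call_compiled(
        const Fn& fun,
        std::uintptr_t fun_id,
        const std::vector<array>& inputs) {
      // detail::compile returns (and caches) a compiled version of `fun` keyed by fun_id.
      auto compiled = mlx::core::detail::compile(fun, fun_id);
      return compiled(inputs);
    }

    // Later, when fun_id is no longer valid:
    //   mlx::core::detail::compile_erase(fun_id);
    // Or to drop every cached entry and force recompilation:
    //   mlx::core::detail::compile_clear_cache();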
@@ -16,21 +16,6 @@ std::vector<array> vmap_replace(
     const std::vector<int>& in_axes,
     const std::vector<int>& out_axes);
 
-// This is not part of the general C++ API as calling with a bad id is a bad
-// idea.
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    std::uintptr_t fun_id,
-    bool shapeless = false,
-    std::vector<uint64_t> constants = {});
-
-// Erase cached compile functions
-void compile_erase(std::uintptr_t fun_id);
-
-// Clear the compiler cache causing a recompilation of all compiled functions
-// when called again.
-void compile_clear_cache();
-
 // Create an InTracing object during tracing operations to signify to the rest
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
         start (float or int, optional): Starting value which defaults to ``0``.
         stop (float or int): Stopping value.
         step (float or int, optional): Increment which defaults to ``1``.
-        dtype (Dtype, optional): Specifies the data type of the output.
-          If unspecified will default to ``float32`` if any of ``start``,
-          ``stop``, or ``step`` are ``float``. Otherwise will default to
-          ``int32``.
+        dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
 
     Returns:
         array: The range of values.
@@ -13,6 +13,7 @@
 
 #include "mlx/array.h"
 #include "mlx/compile.h"
+#include "mlx/compile_impl.h"
 #include "mlx/graph_utils.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()