Mirror of https://github.com/ml-explore/mlx.git

commit 0ab8e099e8 (parent 020f048cd0)
@@ -14,6 +14,30 @@
 
 namespace mlx::core {
 
+struct CompilerCache {
+  struct DLib {
+    DLib(const std::string& libname) {
+      lib = dlopen(libname.c_str(), RTLD_NOW);
+      if (!lib) {
+        std::ostringstream msg;
+        msg << "Could not load C++ shared library " << dlerror();
+        throw std::runtime_error(msg.str());
+      }
+    }
+
+    ~DLib() {
+      dlclose(lib);
+    }
+    void* lib;
+  };
+  // Statics to cache compiled libraries and functions
+  std::list<DLib> libs;
+  std::unordered_map<std::string, void*> kernels;
+  std::shared_mutex mtx;
+};
+
+static CompilerCache cache{};
+
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
 namespace detail {
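Context: DLib is a small RAII wrapper over the platform dynamic loader; the constructor dlopen()s a compiled shared library and the destructor dlclose()s it, so each library stays loaded exactly as long as its entry in cache.libs. A minimal standalone sketch of the same idiom (ScopedLib and the library names are illustrative stand-ins, not MLX code):

    #include <dlfcn.h>

    #include <iostream>
    #include <stdexcept>
    #include <string>

    struct ScopedLib {
      explicit ScopedLib(const std::string& libname)
          : lib(dlopen(libname.c_str(), RTLD_NOW)) {
        if (!lib) {
          throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
        }
      }
      ~ScopedLib() {
        dlclose(lib); // unload when the owner goes out of scope
      }
      void* lib;
    };

    int main() {
      ScopedLib m("libm.dylib"); // "libm.so.6" on Linux
      // dlsym returns the raw symbol address; cast to the expected signature.
      auto cos_fn = reinterpret_cast<double (*)(double)>(dlsym(m.lib, "cos"));
      if (cos_fn) {
        std::cout << cos_fn(0.0) << std::endl; // prints 1
      }
    } // ~ScopedLib runs here and the library is unloaded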
@@ -30,35 +54,15 @@ std::string get_temp_file(const std::string& name) {
 void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
-  struct DLib {
-    DLib(const std::string& libname) {
-      lib = dlopen(libname.c_str(), RTLD_NOW);
-      if (!lib) {
-        std::ostringstream msg;
-        msg << "Could not load C++ shared library " << dlerror();
-        throw std::runtime_error(msg.str());
-      }
-    }
-
-    ~DLib() {
-      dlclose(lib);
-    }
-    void* lib;
-  };
-  // Statics to cache compiled libraries and functions
-  static std::list<DLib> libs;
-  static std::unordered_map<std::string, void*> kernels;
-  static std::shared_mutex compile_mtx;
-
   {
-    std::shared_lock lock(compile_mtx);
-    if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+    std::shared_lock lock(cache.mtx);
+    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
       return it->second;
     }
   }
 
-  std::unique_lock lock(compile_mtx);
-  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
+  std::unique_lock lock(cache.mtx);
+  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
@@ -112,10 +116,10 @@ void* compile(
   }
 
   // load library
-  libs.emplace_back(shared_lib_path);
+  cache.libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -123,7 +127,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  kernels.insert({kernel_name, fun});
+  cache.kernels.insert({kernel_name, fun});
   return fun;
 }
 
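The lookup in compile() is a double-checked read/write pattern on std::shared_mutex: probe the map under a cheap shared (reader) lock, and only if the kernel is missing take the exclusive lock, re-check (another thread may have won the race), and build. A self-contained sketch of just that pattern (the string-to-int "build" is a stand-in for compilation):

    #include <shared_mutex>
    #include <string>
    #include <unordered_map>

    static std::unordered_map<std::string, int> table;
    static std::shared_mutex mtx;

    int get_or_build(const std::string& key) {
      {
        std::shared_lock lock(mtx); // many readers may hold this at once
        if (auto it = table.find(key); it != table.end()) {
          return it->second;
        }
      }
      std::unique_lock lock(mtx); // exclusive: blocks readers and writers
      // Re-check: another thread may have inserted between the two locks.
      if (auto it = table.find(key); it != table.end()) {
        return it->second;
      }
      int value = static_cast<int>(key.size()); // stand-in for the expensive build
      table.insert({key, value});
      return value;
    }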
@@ -6,6 +6,20 @@
 
 namespace mlx::core::detail {
 
-bool compile_available_for_device(const Device& device);
+// This is not part of the general C++ API as calling with a bad id is a bad
+// idea.
+std::function<std::vector<array>(const std::vector<array>&)> compile(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    std::uintptr_t fun_id,
+    bool shapeless = false,
+    std::vector<uint64_t> constants = {});
 
-}
+// Erase cached compile functions
+void compile_erase(std::uintptr_t fun_id);
+
+// Clear the compiler cache causing a recompilation of all compiled functions
+// when called again.
+void compile_clear_cache();
+
+bool compile_available_for_device(const Device& device);
+} // namespace mlx::core::detail
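These internal declarations now live in their own header, so internal callers (e.g. the bindings) can manage the compile cache without pulling in the transform internals. A hedged usage sketch, assuming a translation unit inside MLX that may include private headers (the wrapper function names are illustrative):

    #include <cstdint>

    #include "mlx/compile_impl.h"

    // Forget one traced function, e.g. when its owner is destroyed;
    // fun_id is the same id that was passed to detail::compile().
    void forget_compiled(std::uintptr_t fun_id) {
      mlx::core::detail::compile_erase(fun_id);
    }

    // Drop every cached compiled function; each is rebuilt on next use.
    void reset_compiler_cache() {
      mlx::core::detail::compile_clear_cache();
    }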
@@ -16,21 +16,6 @@ std::vector<array> vmap_replace(
     const std::vector<int>& in_axes,
     const std::vector<int>& out_axes);
 
-// This is not part of the general C++ API as calling with a bad id is a bad
-// idea.
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    std::uintptr_t fun_id,
-    bool shapeless = false,
-    std::vector<uint64_t> constants = {});
-
-// Erase cached compile functions
-void compile_erase(std::uintptr_t fun_id);
-
-// Clear the compiler cache causing a recompilation of all compiled functions
-// when called again.
-void compile_clear_cache();
-
 // Create an InTracing object during tracing operations to signify to the rest
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
@@ -1323,10 +1323,7 @@ void init_ops(nb::module_& m) {
         start (float or int, optional): Starting value which defaults to ``0``.
         stop (float or int): Stopping value.
         step (float or int, optional): Increment which defaults to ``1``.
-        dtype (Dtype, optional): Specifies the data type of the output.
-          If unspecified will default to ``float32`` if any of ``start``,
-          ``stop``, or ``step`` are ``float``. Otherwise will default to
-          ``int32``.
+        dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.
 
     Returns:
         array: The range of values.
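The default described in this docstring (float32 when any of start, stop, or step is a float, otherwise int32) matches the C++ overloads as well; a small sketch, assuming the arange overloads and Dtype constants from mlx/ops.h and mlx/dtype.h:

    #include <cassert>

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto a = arange(10);            // all-int arguments -> int32
      auto b = arange(0.0, 1.0, 0.1); // any float argument -> float32
      auto c = arange(10, float16);   // an explicit dtype overrides the default
      assert(a.dtype() == int32 && b.dtype() == float32 && c.dtype() == float16);
      return 0;
    }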
@@ -13,6 +13,7 @@
 
 #include "mlx/array.h"
 #include "mlx/compile.h"
+#include "mlx/compile_impl.h"
 #include "mlx/graph_utils.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
@@ -171,7 +171,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()
@@ -365,7 +365,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()
@@ -549,7 +549,7 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 
             # use torch to compute ct
             out_pt.retain_grad()
-            (out_pt - torch.randn_like(out_pt)).abs().sum().backward()
+            out_pt.sum().backward()
 
             pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
             pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()