mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
This commit is contained in:
parent
c4fd0e5ede
commit
a7b404ff53
@ -162,7 +162,6 @@ CompileMode& compile_mode() {
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
@ -189,17 +188,18 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
};
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
fnType** fnPointer = f.template target<fnType*>();
|
||||
if (fnPointer == nullptr) {
|
||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
using FunType = T (*)(U...);
|
||||
const FunType* fun_ptr = fun.template target<FunType>();
|
||||
if (fun_ptr == nullptr) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a non-addressable function.");
|
||||
}
|
||||
return (size_t)*fnPointer;
|
||||
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||
}
|
||||
|
||||
struct CompilerCache {
|
||||
class CompilerCache {
|
||||
public:
|
||||
struct CacheEntry {
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
@ -211,20 +211,20 @@ struct CompilerCache {
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
// by the caller to avoid copying large tapes / inputs / outputs
|
||||
CacheEntry& find(
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless,
|
||||
const std::vector<uint64_t>& constants) {
|
||||
// Try to find the entry
|
||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||
auto& entries = entry_it->second;
|
||||
auto is_match = [shapeless](
|
||||
// Find the cache entries for |fun_id|.
|
||||
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||
// Compare if 2 arrays have same shape and dtype.
|
||||
auto has_same_shape_and_dtype = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
for (size_t i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].ndim() != in2[i].ndim()) {
|
||||
return false;
|
||||
}
|
||||
@ -237,14 +237,14 @@ struct CompilerCache {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||
// equal. Note this could get really slow if one compiles the same
|
||||
// function with many different shapes. May want to store entries in a
|
||||
// more easily searchable structure.
|
||||
for (auto& entry : entries) {
|
||||
for (CacheEntry& entry : entries) {
|
||||
// Check the inputs match and return if so
|
||||
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
||||
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||
constants == entry.constants) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
@ -253,7 +253,7 @@ struct CompilerCache {
|
||||
return entries.back();
|
||||
};
|
||||
|
||||
void erase(size_t fun_id) {
|
||||
void erase(std::uintptr_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
@ -263,8 +263,9 @@ struct CompilerCache {
|
||||
// initialized before the compiler cache
|
||||
allocator::allocator();
|
||||
}
|
||||
|
||||
friend CompilerCache& compiler_cache();
|
||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
|
||||
};
|
||||
|
||||
CompilerCache& compiler_cache() {
|
||||
@ -774,7 +775,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (compile_mode() == CompileMode::disabled ||
|
||||
@ -833,7 +834,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
};
|
||||
}
|
||||
|
||||
void compile_erase(size_t fun_id) {
|
||||
void compile_erase(std::uintptr_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
@ -845,7 +846,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
auto fun_id = detail::get_function_address(fun);
|
||||
return detail::compile(fun, fun_id, shapeless);
|
||||
}
|
||||
|
||||
|
@ -20,12 +20,12 @@ std::vector<array> vmap_replace(
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
void compile_erase(std::uintptr_t fun_id);
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
|
@ -313,15 +313,15 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<size_t, nb::object>& tree_cache() {
|
||||
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, nb::object> tree_cache_;
|
||||
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
nb::callable fun;
|
||||
size_t fun_id;
|
||||
std::uintptr_t fun_id;
|
||||
nb::object captured_inputs;
|
||||
nb::object captured_outputs;
|
||||
bool shapeless;
|
||||
@ -333,7 +333,7 @@ struct PyCompiledFun {
|
||||
nb::object outputs,
|
||||
bool shapeless)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
|
||||
captured_inputs(inputs),
|
||||
captured_outputs(outputs),
|
||||
shapeless(shapeless) {}
|
||||
@ -342,7 +342,8 @@ struct PyCompiledFun {
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
: fun(std::move(other.fun)),
|
||||
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
captured_inputs = std::move(other.captured_inputs);
|
||||
captured_outputs = std::move(other.captured_outputs);
|
||||
|
Loading…
Reference in New Issue
Block a user