mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Use uintptr_t instead of size_t to store function id (#916)
Also does some small cleanup of the compile cache code.
This commit is contained in:
parent
c4fd0e5ede
commit
a7b404ff53
@ -162,7 +162,6 @@ CompileMode& compile_mode() {
|
|||||||
return compile_mode_;
|
return compile_mode_;
|
||||||
}
|
}
|
||||||
|
|
||||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
|
||||||
using ParentsMap =
|
using ParentsMap =
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||||
|
|
||||||
@ -189,17 +188,18 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename... U>
|
template <typename T, typename... U>
|
||||||
size_t getAddress(std::function<T(U...)> f) {
|
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||||
typedef T(fnType)(U...);
|
using FunType = T (*)(U...);
|
||||||
fnType** fnPointer = f.template target<fnType*>();
|
const FunType* fun_ptr = fun.template target<FunType>();
|
||||||
if (fnPointer == nullptr) {
|
if (fun_ptr == nullptr) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[compile] Cannot compile a non-addressable function.");
|
"[compile] Cannot compile a non-addressable function.");
|
||||||
}
|
}
|
||||||
return (size_t)*fnPointer;
|
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CompilerCache {
|
class CompilerCache {
|
||||||
|
public:
|
||||||
struct CacheEntry {
|
struct CacheEntry {
|
||||||
std::vector<array> inputs;
|
std::vector<array> inputs;
|
||||||
std::vector<array> outputs;
|
std::vector<array> outputs;
|
||||||
@ -211,20 +211,20 @@ struct CompilerCache {
|
|||||||
// Returns a reference to a CacheEntry which can be updated
|
// Returns a reference to a CacheEntry which can be updated
|
||||||
// by the caller to avoid copying large tapes / inputs / outputs
|
// by the caller to avoid copying large tapes / inputs / outputs
|
||||||
CacheEntry& find(
|
CacheEntry& find(
|
||||||
size_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
bool shapeless,
|
bool shapeless,
|
||||||
const std::vector<uint64_t>& constants) {
|
const std::vector<uint64_t>& constants) {
|
||||||
// Try to find the entry
|
// Find the cache entries for |fun_id|.
|
||||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||||
auto& entries = entry_it->second;
|
// Compare if 2 arrays have same shape and dtype.
|
||||||
auto is_match = [shapeless](
|
auto has_same_shape_and_dtype = [shapeless](
|
||||||
const std::vector<array>& in1,
|
const std::vector<array>& in1,
|
||||||
const std::vector<array>& in2) {
|
const std::vector<array>& in2) {
|
||||||
if (in1.size() != in2.size()) {
|
if (in1.size() != in2.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < in1.size(); ++i) {
|
for (size_t i = 0; i < in1.size(); ++i) {
|
||||||
if (in1[i].ndim() != in2[i].ndim()) {
|
if (in1[i].ndim() != in2[i].ndim()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -237,14 +237,14 @@ struct CompilerCache {
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||||
// equal. Note this could get really slow if one compiles the same
|
// equal. Note this could get really slow if one compiles the same
|
||||||
// function with many different shapes. May want to store entries in a
|
// function with many different shapes. May want to store entries in a
|
||||||
// more easily searchable structure.
|
// more easily searchable structure.
|
||||||
for (auto& entry : entries) {
|
for (CacheEntry& entry : entries) {
|
||||||
// Check the inputs match and return if so
|
// Check the inputs match and return if so
|
||||||
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||||
|
constants == entry.constants) {
|
||||||
return entry;
|
return entry;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -253,7 +253,7 @@ struct CompilerCache {
|
|||||||
return entries.back();
|
return entries.back();
|
||||||
};
|
};
|
||||||
|
|
||||||
void erase(size_t fun_id) {
|
void erase(std::uintptr_t fun_id) {
|
||||||
cache_.erase(fun_id);
|
cache_.erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,8 +263,9 @@ struct CompilerCache {
|
|||||||
// initialized before the compiler cache
|
// initialized before the compiler cache
|
||||||
allocator::allocator();
|
allocator::allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
friend CompilerCache& compiler_cache();
|
friend CompilerCache& compiler_cache();
|
||||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CompilerCache& compiler_cache() {
|
CompilerCache& compiler_cache() {
|
||||||
@ -774,7 +775,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
|||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
bool shapeless /* = false */,
|
bool shapeless /* = false */,
|
||||||
std::vector<uint64_t> constants /* = {} */) {
|
std::vector<uint64_t> constants /* = {} */) {
|
||||||
if (compile_mode() == CompileMode::disabled ||
|
if (compile_mode() == CompileMode::disabled ||
|
||||||
@ -833,7 +834,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
void compile_erase(size_t fun_id) {
|
void compile_erase(std::uintptr_t fun_id) {
|
||||||
detail::compiler_cache().erase(fun_id);
|
detail::compiler_cache().erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -845,7 +846,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
if (detail::compile_mode() == CompileMode::disabled) {
|
if (detail::compile_mode() == CompileMode::disabled) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
auto fun_id = detail::getAddress(fun);
|
auto fun_id = detail::get_function_address(fun);
|
||||||
return detail::compile(fun, fun_id, shapeless);
|
return detail::compile(fun, fun_id, shapeless);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,12 +20,12 @@ std::vector<array> vmap_replace(
|
|||||||
// idea.
|
// idea.
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
size_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
bool shapeless = false,
|
bool shapeless = false,
|
||||||
std::vector<uint64_t> constants = {});
|
std::vector<uint64_t> constants = {});
|
||||||
|
|
||||||
// Erase cached compile functions
|
// Erase cached compile functions
|
||||||
void compile_erase(size_t fun_id);
|
void compile_erase(std::uintptr_t fun_id);
|
||||||
|
|
||||||
// Create an InTracing object during tracing operations to signify to the rest
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
|
@ -313,15 +313,15 @@ auto py_vmap(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<size_t, nb::object>& tree_cache() {
|
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
||||||
// This map is used to Cache the tree structure of the outputs
|
// This map is used to Cache the tree structure of the outputs
|
||||||
static std::unordered_map<size_t, nb::object> tree_cache_;
|
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
||||||
return tree_cache_;
|
return tree_cache_;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PyCompiledFun {
|
struct PyCompiledFun {
|
||||||
nb::callable fun;
|
nb::callable fun;
|
||||||
size_t fun_id;
|
std::uintptr_t fun_id;
|
||||||
nb::object captured_inputs;
|
nb::object captured_inputs;
|
||||||
nb::object captured_outputs;
|
nb::object captured_outputs;
|
||||||
bool shapeless;
|
bool shapeless;
|
||||||
@ -333,7 +333,7 @@ struct PyCompiledFun {
|
|||||||
nb::object outputs,
|
nb::object outputs,
|
||||||
bool shapeless)
|
bool shapeless)
|
||||||
: fun(fun),
|
: fun(fun),
|
||||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
|
||||||
captured_inputs(inputs),
|
captured_inputs(inputs),
|
||||||
captured_outputs(outputs),
|
captured_outputs(outputs),
|
||||||
shapeless(shapeless) {}
|
shapeless(shapeless) {}
|
||||||
@ -342,7 +342,8 @@ struct PyCompiledFun {
|
|||||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||||
PyCompiledFun(PyCompiledFun&& other)
|
PyCompiledFun(PyCompiledFun&& other)
|
||||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
: fun(std::move(other.fun)),
|
||||||
|
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
|
||||||
other.fun_id = 0;
|
other.fun_id = 0;
|
||||||
captured_inputs = std::move(other.captured_inputs);
|
captured_inputs = std::move(other.captured_inputs);
|
||||||
captured_outputs = std::move(other.captured_outputs);
|
captured_outputs = std::move(other.captured_outputs);
|
||||||
|
Loading…
Reference in New Issue
Block a user