Use uintptr_t instead of size_t to store function id (#916)

Also does some small cleanup of the compile cache code.
This commit is contained in:
Cheng 2024-03-28 22:37:59 +09:00 committed by GitHub
parent c4fd0e5ede
commit a7b404ff53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 30 deletions

View File

@ -162,7 +162,6 @@ CompileMode& compile_mode() {
return compile_mode_;
}
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
@ -189,17 +188,18 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
};
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
fnType** fnPointer = f.template target<fnType*>();
if (fnPointer == nullptr) {
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
using FunType = T (*)(U...);
const FunType* fun_ptr = fun.template target<FunType>();
if (fun_ptr == nullptr) {
throw std::invalid_argument(
"[compile] Cannot compile a non-addressable function.");
}
return (size_t)*fnPointer;
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
}
struct CompilerCache {
class CompilerCache {
public:
struct CacheEntry {
std::vector<array> inputs;
std::vector<array> outputs;
@ -211,20 +211,20 @@ struct CompilerCache {
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(
size_t fun_id,
std::uintptr_t fun_id,
const std::vector<array>& inputs,
bool shapeless,
const std::vector<uint64_t>& constants) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [shapeless](
// Find the cache entries for |fun_id|.
std::vector<CacheEntry>& entries = cache_[fun_id];
// Compare if 2 arrays have same shape and dtype.
auto has_same_shape_and_dtype = [shapeless](
const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
return false;
}
for (int i = 0; i < in1.size(); ++i) {
for (size_t i = 0; i < in1.size(); ++i) {
if (in1[i].ndim() != in2[i].ndim()) {
return false;
}
@ -237,14 +237,14 @@ struct CompilerCache {
}
return true;
};
// Loop over entries and check inputs match i.e. shapes and types must be
// equal. Note this could get really slow if one compiles the same
// function with many different shapes. May want to store entries in a
// more easily searchable structure.
for (auto& entry : entries) {
for (CacheEntry& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
constants == entry.constants) {
return entry;
}
}
@ -253,7 +253,7 @@ struct CompilerCache {
return entries.back();
};
void erase(size_t fun_id) {
void erase(std::uintptr_t fun_id) {
cache_.erase(fun_id);
}
@ -263,8 +263,9 @@ struct CompilerCache {
// initialized before the compiler cache
allocator::allocator();
}
friend CompilerCache& compiler_cache();
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
};
CompilerCache& compiler_cache() {
@ -774,7 +775,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id,
std::uintptr_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled ||
@ -833,7 +834,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
};
}
void compile_erase(size_t fun_id) {
void compile_erase(std::uintptr_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
@ -845,7 +846,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
auto fun_id = detail::get_function_address(fun);
return detail::compile(fun, fun_id, shapeless);
}

View File

@ -20,12 +20,12 @@ std::vector<array> vmap_replace(
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id,
std::uintptr_t fun_id,
bool shapeless = false,
std::vector<uint64_t> constants = {});
// Erase cached compile functions
void compile_erase(size_t fun_id);
void compile_erase(std::uintptr_t fun_id);
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away

View File

@ -313,15 +313,15 @@ auto py_vmap(
};
}
std::unordered_map<size_t, nb::object>& tree_cache() {
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, nb::object> tree_cache_;
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
return tree_cache_;
}
struct PyCompiledFun {
nb::callable fun;
size_t fun_id;
std::uintptr_t fun_id;
nb::object captured_inputs;
nb::object captured_outputs;
bool shapeless;
@ -333,7 +333,7 @@ struct PyCompiledFun {
nb::object outputs,
bool shapeless)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs),
shapeless(shapeless) {}
@ -342,7 +342,8 @@ struct PyCompiledFun {
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
: fun(std::move(other.fun)),
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);