Use uintptr_t instead of size_t to store funtion id (#916)

Also does some small cleanup of the compile cache code.
This commit is contained in:
Cheng
2024-03-28 22:37:59 +09:00
committed by GitHub
parent c4fd0e5ede
commit a7b404ff53
3 changed files with 32 additions and 30 deletions

View File

@@ -313,15 +313,15 @@ auto py_vmap(
};
}
std::unordered_map<size_t, nb::object>& tree_cache() {
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, nb::object> tree_cache_;
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
return tree_cache_;
}
struct PyCompiledFun {
nb::callable fun;
size_t fun_id;
std::uintptr_t fun_id;
nb::object captured_inputs;
nb::object captured_outputs;
bool shapeless;
@@ -333,7 +333,7 @@ struct PyCompiledFun {
nb::object outputs,
bool shapeless)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs),
shapeless(shapeless) {}
@@ -342,7 +342,8 @@ struct PyCompiledFun {
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
: fun(std::move(other.fun)),
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);