mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
This commit is contained in:
@@ -313,15 +313,15 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<size_t, nb::object>& tree_cache() {
|
||||
std::unordered_map<std::uintptr_t, nb::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, nb::object> tree_cache_;
|
||||
static std::unordered_map<std::uintptr_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
nb::callable fun;
|
||||
size_t fun_id;
|
||||
std::uintptr_t fun_id;
|
||||
nb::object captured_inputs;
|
||||
nb::object captured_outputs;
|
||||
bool shapeless;
|
||||
@@ -333,7 +333,7 @@ struct PyCompiledFun {
|
||||
nb::object outputs,
|
||||
bool shapeless)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
|
||||
captured_inputs(inputs),
|
||||
captured_outputs(outputs),
|
||||
shapeless(shapeless) {}
|
||||
@@ -342,7 +342,8 @@ struct PyCompiledFun {
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
: fun(std::move(other.fun)),
|
||||
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
captured_inputs = std::move(other.captured_inputs);
|
||||
captured_outputs = std::move(other.captured_outputs);
|
||||
|
Reference in New Issue
Block a user