From a7b404ff53b6119bf28bcce3f7d25ed595ad4f94 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 28 Mar 2024 22:37:59 +0900 Subject: [PATCH] Use uintptr_t instead of size_t to store function id (#916) Also does some small cleanup of the compile cache code. --- mlx/compile.cpp | 47 ++++++++++++++++++++------------------- mlx/transforms_impl.h | 4 ++-- python/src/transforms.cpp | 11 ++++----- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 13b12fee3..28419351c 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -162,7 +162,6 @@ CompileMode& compile_mode() { return compile_mode_; } -using CompileFn = std::function(const std::vector&)>; using ParentsMap = std::unordered_map>>; @@ -189,17 +188,18 @@ void merge(array& dst, array& src, ParentsMap& parents_map) { }; template -size_t getAddress(std::function f) { - typedef T(fnType)(U...); - fnType** fnPointer = f.template target(); - if (fnPointer == nullptr) { +std::uintptr_t get_function_address(const std::function& fun) { + using FunType = T (*)(U...); + const FunType* fun_ptr = fun.template target(); + if (fun_ptr == nullptr) { throw std::invalid_argument( "[compile] Cannot compile a non-addressable function."); } - return (size_t)*fnPointer; + return reinterpret_cast(*fun_ptr); } -struct CompilerCache { +class CompilerCache { + public: struct CacheEntry { std::vector inputs; std::vector outputs; @@ -211,20 +211,20 @@ struct CompilerCache { // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs CacheEntry& find( - size_t fun_id, + std::uintptr_t fun_id, const std::vector& inputs, bool shapeless, const std::vector& constants) { - // Try to find the entry - auto [entry_it, inserted] = cache_.insert({fun_id, {}}); - auto& entries = entry_it->second; - auto is_match = [shapeless]( - const std::vector& in1, - const std::vector& in2) { + // Find the cache entries for |fun_id|. 
+ std::vector& entries = cache_[fun_id]; + // Compare if 2 arrays have same shape and dtype. + auto has_same_shape_and_dtype = [shapeless]( + const std::vector& in1, + const std::vector& in2) { if (in1.size() != in2.size()) { return false; } - for (int i = 0; i < in1.size(); ++i) { + for (size_t i = 0; i < in1.size(); ++i) { if (in1[i].ndim() != in2[i].ndim()) { return false; } @@ -237,14 +237,14 @@ struct CompilerCache { } return true; }; - // Loop over entries and check inputs match i.e. shapes and types must be // equal. Note this could get really slow if one compiles the same // function with many different shapes. May want to store entries in a // more easily searchable structure. - for (auto& entry : entries) { + for (CacheEntry& entry : entries) { // Check the inputs match and return if so - if (is_match(inputs, entry.inputs) && constants == entry.constants) { + if (has_same_shape_and_dtype(inputs, entry.inputs) && + constants == entry.constants) { return entry; } } @@ -253,7 +253,7 @@ struct CompilerCache { return entries.back(); }; - void erase(size_t fun_id) { + void erase(std::uintptr_t fun_id) { cache_.erase(fun_id); } @@ -263,8 +263,9 @@ struct CompilerCache { // initialized before the compiler cache allocator::allocator(); } + friend CompilerCache& compiler_cache(); - std::unordered_map> cache_; + std::unordered_map> cache_; }; CompilerCache& compiler_cache() { @@ -774,7 +775,7 @@ void compile_validate_shapeless(const std::vector& tape) { std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id, + std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { if (compile_mode() == CompileMode::disabled || @@ -833,7 +834,7 @@ std::function(const std::vector&)> compile( }; } -void compile_erase(size_t fun_id) { +void compile_erase(std::uintptr_t fun_id) { detail::compiler_cache().erase(fun_id); } @@ -845,7 +846,7 @@ std::function(const std::vector&)> compile( if 
(detail::compile_mode() == CompileMode::disabled) { return fun; } - auto fun_id = detail::getAddress(fun); + auto fun_id = detail::get_function_address(fun); return detail::compile(fun, fun_id, shapeless); } diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 283dad194..f81dba24f 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -20,12 +20,12 @@ std::vector vmap_replace( // idea. std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, - size_t fun_id, + std::uintptr_t fun_id, bool shapeless = false, std::vector constants = {}); // Erase cached compile functions -void compile_erase(size_t fun_id); +void compile_erase(std::uintptr_t fun_id); // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 215c441de..fa9f9235a 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -313,15 +313,15 @@ auto py_vmap( }; } -std::unordered_map& tree_cache() { +std::unordered_map& tree_cache() { // This map is used to Cache the tree structure of the outputs - static std::unordered_map tree_cache_; + static std::unordered_map tree_cache_; return tree_cache_; } struct PyCompiledFun { nb::callable fun; - size_t fun_id; + std::uintptr_t fun_id; nb::object captured_inputs; nb::object captured_outputs; bool shapeless; @@ -333,7 +333,7 @@ struct PyCompiledFun { nb::object outputs, bool shapeless) : fun(fun), - fun_id(reinterpret_cast(fun.ptr())), + fun_id(reinterpret_cast(fun.ptr())), captured_inputs(inputs), captured_outputs(outputs), shapeless(shapeless) {} @@ -342,7 +342,8 @@ struct PyCompiledFun { PyCompiledFun& operator=(const PyCompiledFun&) = delete; PyCompiledFun& operator=(PyCompiledFun&& other) = delete; PyCompiledFun(PyCompiledFun&& other) - : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { + : 
fun(std::move(other.fun)), + fun_id(reinterpret_cast(fun.ptr())) { other.fun_id = 0; captured_inputs = std::move(other.captured_inputs); captured_outputs = std::move(other.captured_outputs);