mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
This commit is contained in:
@@ -162,7 +162,6 @@ CompileMode& compile_mode() {
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
@@ -189,17 +188,18 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
};
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
fnType** fnPointer = f.template target<fnType*>();
|
||||
if (fnPointer == nullptr) {
|
||||
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
using FunType = T (*)(U...);
|
||||
const FunType* fun_ptr = fun.template target<FunType>();
|
||||
if (fun_ptr == nullptr) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a non-addressable function.");
|
||||
}
|
||||
return (size_t)*fnPointer;
|
||||
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||
}
|
||||
|
||||
struct CompilerCache {
|
||||
class CompilerCache {
|
||||
public:
|
||||
struct CacheEntry {
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
@@ -211,20 +211,20 @@ struct CompilerCache {
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
// by the caller to avoid copying large tapes / inputs / outputs
|
||||
CacheEntry& find(
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless,
|
||||
const std::vector<uint64_t>& constants) {
|
||||
// Try to find the entry
|
||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||
auto& entries = entry_it->second;
|
||||
auto is_match = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
// Find the cache entries for |fun_id|.
|
||||
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||
// Compare if 2 arrays have same shape and dtype.
|
||||
auto has_same_shape_and_dtype = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
for (size_t i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].ndim() != in2[i].ndim()) {
|
||||
return false;
|
||||
}
|
||||
@@ -237,14 +237,14 @@ struct CompilerCache {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||
// equal. Note this could get really slow if one compiles the same
|
||||
// function with many different shapes. May want to store entries in a
|
||||
// more easily searchable structure.
|
||||
for (auto& entry : entries) {
|
||||
for (CacheEntry& entry : entries) {
|
||||
// Check the inputs match and return if so
|
||||
if (is_match(inputs, entry.inputs) && constants == entry.constants) {
|
||||
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||
constants == entry.constants) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
@@ -253,7 +253,7 @@ struct CompilerCache {
|
||||
return entries.back();
|
||||
};
|
||||
|
||||
void erase(size_t fun_id) {
|
||||
void erase(std::uintptr_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
@@ -263,8 +263,9 @@ struct CompilerCache {
|
||||
// initialized before the compiler cache
|
||||
allocator::allocator();
|
||||
}
|
||||
|
||||
friend CompilerCache& compiler_cache();
|
||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
|
||||
};
|
||||
|
||||
CompilerCache& compiler_cache() {
|
||||
@@ -774,7 +775,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (compile_mode() == CompileMode::disabled ||
|
||||
@@ -833,7 +834,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
};
|
||||
}
|
||||
|
||||
void compile_erase(size_t fun_id) {
|
||||
void compile_erase(std::uintptr_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
@@ -845,7 +846,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
auto fun_id = detail::get_function_address(fun);
|
||||
return detail::compile(fun, fun_id, shapeless);
|
||||
}
|
||||
|
||||
|
@@ -20,12 +20,12 @@ std::vector<array> vmap_replace(
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
void compile_erase(std::uintptr_t fun_id);
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
|
Reference in New Issue
Block a user