mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:11:43 +08:00
compile changes if stream changes (#1644)
This commit is contained in:
@@ -211,6 +211,8 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
class CompilerCache {
|
||||
public:
|
||||
struct CacheEntry {
|
||||
CacheEntry(Stream stream) : stream(stream) {};
|
||||
Stream stream;
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
std::vector<array> tape;
|
||||
@@ -227,6 +229,7 @@ class CompilerCache {
|
||||
const std::vector<uint64_t>& constants) {
|
||||
// Find the cache entries for |fun_id|.
|
||||
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||
|
||||
// Compare if 2 arrays have same shape and dtype.
|
||||
auto has_same_shape_and_dtype = [shapeless](
|
||||
const std::vector<array>& in1,
|
||||
@@ -247,11 +250,16 @@ class CompilerCache {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||
// equal. Note this could get really slow if one compiles the same
|
||||
// function with many different shapes. May want to store entries in a
|
||||
// more easily searchable structure.
|
||||
// Loop over entries and check:
|
||||
// - Default stream and device match the entry's default stream
|
||||
// - Inputs match i.e. shapes and types must be equal.
|
||||
auto stream = default_stream(default_device());
|
||||
for (CacheEntry& entry : entries) {
|
||||
// Check that the default stream and device match
|
||||
if (entry.stream != stream) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the inputs match and return if so
|
||||
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||
constants == entry.constants) {
|
||||
@@ -259,7 +267,7 @@ class CompilerCache {
|
||||
}
|
||||
}
|
||||
// Otherwise append a new cache entry
|
||||
entries.push_back(CacheEntry{});
|
||||
entries.push_back(CacheEntry{stream});
|
||||
return entries.back();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user