mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
compile changes if stream changes (#1644)
This commit is contained in:
parent
9d40e521d7
commit
e047fd977d
@ -211,6 +211,8 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
|||||||
class CompilerCache {
|
class CompilerCache {
|
||||||
public:
|
public:
|
||||||
struct CacheEntry {
|
struct CacheEntry {
|
||||||
|
CacheEntry(Stream stream) : stream(stream) {};
|
||||||
|
Stream stream;
|
||||||
std::vector<array> inputs;
|
std::vector<array> inputs;
|
||||||
std::vector<array> outputs;
|
std::vector<array> outputs;
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
@ -227,6 +229,7 @@ class CompilerCache {
|
|||||||
const std::vector<uint64_t>& constants) {
|
const std::vector<uint64_t>& constants) {
|
||||||
// Find the cache entries for |fun_id|.
|
// Find the cache entries for |fun_id|.
|
||||||
std::vector<CacheEntry>& entries = cache_[fun_id];
|
std::vector<CacheEntry>& entries = cache_[fun_id];
|
||||||
|
|
||||||
// Compare if 2 arrays have same shape and dtype.
|
// Compare if 2 arrays have same shape and dtype.
|
||||||
auto has_same_shape_and_dtype = [shapeless](
|
auto has_same_shape_and_dtype = [shapeless](
|
||||||
const std::vector<array>& in1,
|
const std::vector<array>& in1,
|
||||||
@ -247,11 +250,16 @@ class CompilerCache {
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
// Loop over entries and check:
|
||||||
// equal. Note this could get really slow if one compiles the same
|
// - Default stream and device match the entry's default stream
|
||||||
// function with many different shapes. May want to store entries in a
|
// - Inputs match i.e. shapes and types must be equal.
|
||||||
// more easily searchable structure.
|
auto stream = default_stream(default_device());
|
||||||
for (CacheEntry& entry : entries) {
|
for (CacheEntry& entry : entries) {
|
||||||
|
// Check that the default stream and device match
|
||||||
|
if (entry.stream != stream) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Check the inputs match and return if so
|
// Check the inputs match and return if so
|
||||||
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
||||||
constants == entry.constants) {
|
constants == entry.constants) {
|
||||||
@ -259,7 +267,7 @@ class CompilerCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Otherwise append a new cache entry
|
// Otherwise append a new cache entry
|
||||||
entries.push_back(CacheEntry{});
|
entries.push_back(CacheEntry{stream});
|
||||||
return entries.back();
|
return entries.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,10 +184,6 @@ void init_array(nb::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A helper object to apply updates at specific indices.
|
A helper object to apply updates at specific indices.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(
|
|
||||||
nb::init<const array&>(),
|
|
||||||
"x"_a,
|
|
||||||
nb::sig("def __init__(self, x: array)"))
|
|
||||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
|
.def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
|
||||||
.def("add", &ArrayAt::add, "value"_a)
|
.def("add", &ArrayAt::add, "value"_a)
|
||||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||||
@ -202,10 +198,6 @@ void init_array(nb::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A helper object to iterate over the 1st dimension of an array.
|
A helper object to iterate over the 1st dimension of an array.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(
|
|
||||||
nb::init<const array&>(),
|
|
||||||
"x"_a,
|
|
||||||
nb::sig("def __init__(self, x: array)"))
|
|
||||||
.def("__next__", &ArrayPythonIterator::next)
|
.def("__next__", &ArrayPythonIterator::next)
|
||||||
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||||
|
|
||||||
|
@ -48,7 +48,6 @@ void init_stream(nb::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A stream for running operations on a given device.
|
A stream for running operations on a given device.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(nb::init<int, Device>(), "index"_a, "device"_a)
|
|
||||||
.def_ro("device", &Stream::device)
|
.def_ro("device", &Stream::device)
|
||||||
.def(
|
.def(
|
||||||
"__repr__",
|
"__repr__",
|
||||||
|
@ -719,3 +719,14 @@ TEST_CASE("test compile strides") {
|
|||||||
CHECK_EQ(out.strides().size(), 3);
|
CHECK_EQ(out.strides().size(), 3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile change streams") {
|
||||||
|
auto cfun = compile(simple_fun);
|
||||||
|
auto out = cfun({array(1.0f), array(2.0f)})[0];
|
||||||
|
CHECK_EQ(out.primitive().stream(), default_stream(default_device()));
|
||||||
|
|
||||||
|
auto s = new_stream(default_device());
|
||||||
|
StreamContext sctx(s);
|
||||||
|
out = cfun({array(1.0f), array(2.0f)})[0];
|
||||||
|
CHECK_EQ(out.primitive().stream(), s);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user