compile changes if stream changes (#1644)

This commit is contained in:
Awni Hannun 2024-12-03 14:37:44 -08:00 committed by GitHub
parent 9d40e521d7
commit e047fd977d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 14 deletions

View File

@ -211,6 +211,8 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
class CompilerCache { class CompilerCache {
public: public:
struct CacheEntry { struct CacheEntry {
CacheEntry(Stream stream) : stream(stream) {};
Stream stream;
std::vector<array> inputs; std::vector<array> inputs;
std::vector<array> outputs; std::vector<array> outputs;
std::vector<array> tape; std::vector<array> tape;
@ -227,6 +229,7 @@ class CompilerCache {
const std::vector<uint64_t>& constants) { const std::vector<uint64_t>& constants) {
// Find the cache entries for |fun_id|. // Find the cache entries for |fun_id|.
std::vector<CacheEntry>& entries = cache_[fun_id]; std::vector<CacheEntry>& entries = cache_[fun_id];
// Compare if 2 arrays have same shape and dtype. // Compare if 2 arrays have same shape and dtype.
auto has_same_shape_and_dtype = [shapeless]( auto has_same_shape_and_dtype = [shapeless](
const std::vector<array>& in1, const std::vector<array>& in1,
@ -247,11 +250,16 @@ class CompilerCache {
} }
return true; return true;
}; };
// Loop over entries and check inputs match i.e. shapes and types must be // Loop over entries and check:
// equal. Note this could get really slow if one compiles the same // - Default stream and device match the entry's default stream
// function with many different shapes. May want to store entries in a // - Inputs match i.e. shapes and types must be equal.
// more easily searchable structure. auto stream = default_stream(default_device());
for (CacheEntry& entry : entries) { for (CacheEntry& entry : entries) {
// Check that the default stream and device match
if (entry.stream != stream) {
continue;
}
// Check the inputs match and return if so // Check the inputs match and return if so
if (has_same_shape_and_dtype(inputs, entry.inputs) && if (has_same_shape_and_dtype(inputs, entry.inputs) &&
constants == entry.constants) { constants == entry.constants) {
@ -259,7 +267,7 @@ class CompilerCache {
} }
} }
// Otherwise append a new cache entry // Otherwise append a new cache entry
entries.push_back(CacheEntry{}); entries.push_back(CacheEntry{stream});
return entries.back(); return entries.back();
} }

View File

@ -184,10 +184,6 @@ void init_array(nb::module_& m) {
R"pbdoc( R"pbdoc(
A helper object to apply updates at specific indices. A helper object to apply updates at specific indices.
)pbdoc") )pbdoc")
.def(
nb::init<const array&>(),
"x"_a,
nb::sig("def __init__(self, x: array)"))
.def("__getitem__", &ArrayAt::set_indices, "indices"_a.none()) .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
.def("add", &ArrayAt::add, "value"_a) .def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a) .def("subtract", &ArrayAt::subtract, "value"_a)
@ -202,10 +198,6 @@ void init_array(nb::module_& m) {
R"pbdoc( R"pbdoc(
A helper object to iterate over the 1st dimension of an array. A helper object to iterate over the 1st dimension of an array.
)pbdoc") )pbdoc")
.def(
nb::init<const array&>(),
"x"_a,
nb::sig("def __init__(self, x: array)"))
.def("__next__", &ArrayPythonIterator::next) .def("__next__", &ArrayPythonIterator::next)
.def("__iter__", [](const ArrayPythonIterator& it) { return it; }); .def("__iter__", [](const ArrayPythonIterator& it) { return it; });

View File

@ -48,7 +48,6 @@ void init_stream(nb::module_& m) {
R"pbdoc( R"pbdoc(
A stream for running operations on a given device. A stream for running operations on a given device.
)pbdoc") )pbdoc")
.def(nb::init<int, Device>(), "index"_a, "device"_a)
.def_ro("device", &Stream::device) .def_ro("device", &Stream::device)
.def( .def(
"__repr__", "__repr__",

View File

@ -719,3 +719,14 @@ TEST_CASE("test compile strides") {
CHECK_EQ(out.strides().size(), 3); CHECK_EQ(out.strides().size(), 3);
} }
} }
TEST_CASE("test compile change streams") {
auto cfun = compile(simple_fun);
auto out = cfun({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.primitive().stream(), default_stream(default_device()));
auto s = new_stream(default_device());
StreamContext sctx(s);
out = cfun({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.primitive().stream(), s);
}