ensure io/comm streams are active before eval (#1412)

This commit is contained in:
Awni Hannun 2024-09-14 06:17:36 -07:00 committed by GitHub
parent bd8396fad8
commit b3f52c9fbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 4 deletions

View File

@ -199,7 +199,6 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
static Stream io_stream = new_stream(Device::cpu);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@ -213,7 +212,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
fut.wait();
out.event().signal();
};
scheduler::enqueue(io_stream, std::move(signal_task));
scheduler::enqueue(io_stream(), std::move(signal_task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);

View File

@ -255,6 +255,9 @@ Group init(bool strict /* = false */) {
}
}
// Ensure the communication stream is alive before
// the graph is evaluated
detail::communication_stream();
return Group(global_group);
}

View File

@ -1146,9 +1146,13 @@ class Load : public UnaryPrimitive {
size_t offset,
bool swap_endianness = false)
: UnaryPrimitive(stream),
reader_(reader),
reader_(std::move(reader)),
offset_(offset),
swap_endianness_(swap_endianness) {}
swap_endianness_(swap_endianness) {
if (stream.device == Device::gpu) {
io_stream();
}
}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1156,6 +1160,10 @@ class Load : public UnaryPrimitive {
DEFINE_PRINT(Load)
private:
// Accessor for the stream on which Load enqueues its IO tasks.
// The stream is a function-local static: it is created with
// new_stream(Device::cpu) the first time this accessor runs, and the same
// object is handed back (by reference) on every later call.
Stream& io_stream() {
  static Stream load_io_stream = new_stream(Device::cpu);
  return load_io_stream;
};
void eval(const std::vector<array>& inputs, array& out);
std::shared_ptr<io::Reader> reader_;
size_t offset_;