mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-02 23:31:16 +08:00)

ensure io/comm streams are active before eval (#1412)

commit b3f52c9fbe (parent bd8396fad8)
@@ -199,7 +199,6 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
-  static Stream io_stream = new_stream(Device::cpu);
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   auto read_task = [out = out,
@@ -213,7 +212,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
     fut.wait();
     out.event().signal();
   };
-  scheduler::enqueue(io_stream, std::move(signal_task));
+  scheduler::enqueue(io_stream(), std::move(signal_task));
   auto& d = metal::device(stream().device);
   d.end_encoding(stream().index);
   auto command_buffer = d.get_command_buffer(stream().index);
@@ -255,6 +255,9 @@ Group init(bool strict /* = false */) {
     }
   }
 
+  // Ensure the communication stream is alive before
+  // the graph is evaluated
+  detail::communication_stream();
   return Group(global_group);
 }
 
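The hunk above makes the distributed `Group init()` call `detail::communication_stream()` before returning, so the stream behind that accessor is constructed during setup rather than on first use inside an evaluation. Below is a minimal standalone sketch of that ordering; `CommStream`, `communication_stream()`, `init()`, and `eval_graph()` are simplified stand-ins for illustration, not the real MLX types or functions.

#include <cstdio>
#include <thread>

// Hypothetical stand-in for the lazily created communication stream:
// a function-local static that is constructed on the first call only.
struct CommStream {
  CommStream() { std::printf("communication stream created\n"); }
};

CommStream& communication_stream() {
  static CommStream stream;  // constructed once, on first use
  return stream;
}

void init() {
  // Touch the accessor here, mirroring the detail::communication_stream()
  // call added above, so the stream exists before any evaluation starts.
  communication_stream();
  std::printf("group initialized\n");
}

void eval_graph() {
  // Without the call in init(), this first use would construct the stream
  // here, in the middle of evaluation and possibly on a worker thread.
  communication_stream();
  std::printf("graph evaluated\n");
}

int main() {
  init();
  std::thread worker(eval_graph);
  worker.join();
  return 0;
}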
@@ -1146,9 +1146,13 @@ class Load : public UnaryPrimitive {
       size_t offset,
       bool swap_endianness = false)
       : UnaryPrimitive(stream),
-        reader_(reader),
+        reader_(std::move(reader)),
         offset_(offset),
-        swap_endianness_(swap_endianness) {}
+        swap_endianness_(swap_endianness) {
+    if (stream.device == Device::gpu) {
+      io_stream();
+    }
+  }
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1156,6 +1160,10 @@ class Load : public UnaryPrimitive {
   DEFINE_PRINT(Load)
 
  private:
+  Stream& io_stream() {
+    static Stream io_stream = new_stream(Device::cpu);
+    return io_stream;
+  };
   void eval(const std::vector<array>& inputs, array& out);
   std::shared_ptr<io::Reader> reader_;
   size_t offset_;
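Taken together, the two hunks above replace the function-local static that used to live inside `Load::eval_gpu` with a private `io_stream()` accessor, and the `Load` constructor touches that accessor when the primitive targets the GPU, so the CPU io stream is created before the graph is evaluated. Below is a minimal standalone sketch of the pattern; `Stream`, `new_stream`, `Device`, and the simplified `Load` class are stand-ins for illustration, not the real MLX declarations.

#include <cstdio>

// Simplified stand-ins for the real MLX types; for illustration only.
enum class Device { cpu, gpu };
struct Stream {
  int index;
  Device device;
};

Stream new_stream(Device d) {
  static int next_index = 0;
  std::printf("creating stream %d\n", next_index);
  return Stream{next_index++, d};
}

class Load {
 public:
  explicit Load(Device device) : device_(device) {
    // Touch the accessor up front so the function-local static is
    // constructed now, before any evaluation runs.
    if (device_ == Device::gpu) {
      io_stream();
    }
  }

  void eval_gpu() {
    // By the time eval runs, io_stream() already refers to a live stream,
    // so the read task can be scheduled onto an existing stream.
    std::printf("enqueueing read on stream %d\n", io_stream().index);
  }

 private:
  Stream& io_stream() {
    // One CPU stream shared by all Load instances, created on first call.
    static Stream io_stream = new_stream(Device::cpu);
    return io_stream;
  }

  Device device_;
};

int main() {
  Load load(Device::gpu);  // the io stream is created here, eagerly
  load.eval_gpu();         // reuses the same stream
  return 0;
}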