mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
ensure io/comm streams are active before eval (#1412)
This commit is contained in:
@@ -1146,9 +1146,13 @@ class Load : public UnaryPrimitive {
|
||||
size_t offset,
|
||||
bool swap_endianness = false)
|
||||
: UnaryPrimitive(stream),
|
||||
reader_(reader),
|
||||
reader_(std::move(reader)),
|
||||
offset_(offset),
|
||||
swap_endianness_(swap_endianness) {}
|
||||
swap_endianness_(swap_endianness) {
|
||||
if (stream.device == Device::gpu) {
|
||||
io_stream();
|
||||
}
|
||||
}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -1156,6 +1160,10 @@ class Load : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Load)
|
||||
|
||||
private:
|
||||
Stream& io_stream() {
|
||||
static Stream io_stream = new_stream(Device::cpu);
|
||||
return io_stream;
|
||||
};
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::shared_ptr<io::Reader> reader_;
|
||||
size_t offset_;
|
||||
|
||||
Reference in New Issue
Block a user