mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
@@ -4,12 +4,14 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -197,7 +199,24 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
static Stream io_stream = new_stream(Device::cpu);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [out = out,
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness = swap_endianness_]() mutable {
|
||||
load(out, offset, reader, swap_endianness);
|
||||
out.event().signal();
|
||||
};
|
||||
|
||||
scheduler::enqueue(io_stream, std::move(task));
|
||||
auto& d = metal::device(stream().device);
|
||||
d.end_encoding(stream().index);
|
||||
auto command_buffer = d.get_command_buffer(stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
Reference in New Issue
Block a user