mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
parent
fcb65a3897
commit
a6c3b38fba
@ -5,11 +5,9 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/io/load.h"
|
#include "mlx/backend/common/load.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <const uint8_t scalar_size>
|
template <const uint8_t scalar_size>
|
||||||
@ -29,11 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
namespace mlx::core {
|
||||||
assert(inputs.size() == 0);
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
reader_->read(out.data<char>(), out.nbytes(), offset_);
|
void load(
|
||||||
|
array& out,
|
||||||
|
size_t offset,
|
||||||
|
const std::shared_ptr<io::Reader>& reader,
|
||||||
|
bool swap_endianness_) {
|
||||||
|
reader->read(out.data<char>(), out.nbytes(), offset);
|
||||||
|
|
||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
switch (out.itemsize()) {
|
switch (out.itemsize()) {
|
||||||
@ -50,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
load(out, offset_, reader_, swap_endianness_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
14
mlx/backend/common/load.h
Normal file
14
mlx/backend/common/load.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/io/load.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void load(
|
||||||
|
array& out,
|
||||||
|
size_t offset,
|
||||||
|
const std::shared_ptr<io::Reader>& reader,
|
||||||
|
bool swap_endianess);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -4,12 +4,14 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/load.h"
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/slicing.h"
|
#include "mlx/backend/metal/slicing.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -197,7 +199,24 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
eval(inputs, out);
|
static Stream io_stream = new_stream(Device::cpu);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto task = [out = out,
|
||||||
|
offset = offset_,
|
||||||
|
reader = reader_,
|
||||||
|
swap_endianness = swap_endianness_]() mutable {
|
||||||
|
load(out, offset, reader, swap_endianness);
|
||||||
|
out.event().signal();
|
||||||
|
};
|
||||||
|
|
||||||
|
scheduler::enqueue(io_stream, std::move(task));
|
||||||
|
auto& d = metal::device(stream().device);
|
||||||
|
d.end_encoding(stream().index);
|
||||||
|
auto command_buffer = d.get_command_buffer(stream().index);
|
||||||
|
command_buffer->encodeWait(
|
||||||
|
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||||
|
out.event().value());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
Loading…
Reference in New Issue
Block a user