mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
load changes were not needed after all
This commit is contained in:
parent
60132a16de
commit
f4d876d35f
@ -31,13 +31,10 @@ void swap_endianess(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
if (len_ == 0) {
|
||||
len_ = out.nbytes();
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(len_));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), len_);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
|
@ -225,8 +225,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
auto loaded_array = array(
|
||||
shape,
|
||||
dtype,
|
||||
std::make_unique<Load>(
|
||||
to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness),
|
||||
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||
std::vector<array>{});
|
||||
if (col_contiguous) {
|
||||
loaded_array = transpose(loaded_array, s);
|
||||
|
@ -826,7 +826,6 @@ class Load : public Primitive {
|
||||
Stream stream,
|
||||
std::shared_ptr<io::Reader> reader,
|
||||
size_t offset,
|
||||
size_t len,
|
||||
bool swap_endianness = false)
|
||||
: Primitive(stream),
|
||||
reader_(reader),
|
||||
@ -842,7 +841,6 @@ class Load : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
std::shared_ptr<io::Reader> reader_;
|
||||
size_t offset_;
|
||||
size_t len_;
|
||||
bool swap_endianness_;
|
||||
};
|
||||
|
||||
|
@ -307,7 +307,6 @@ std::unordered_map<std::string, array> load_safetensor(
|
||||
to_stream(s),
|
||||
in_stream,
|
||||
offset + data_offsets->at(0)->getNumber(),
|
||||
offset + data_offsets->at(1)->getNumber(),
|
||||
false),
|
||||
std::vector<array>{});
|
||||
res.insert({key, loaded_array});
|
||||
|
@ -165,7 +165,6 @@ TEST_CASE("test save_safetensor") {
|
||||
auto map = std::unordered_map<std::string, array>();
|
||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||
map.insert({"test2", ones({2, 2})});
|
||||
MESSAGE("SAVING");
|
||||
save_safetensor(file_path, map);
|
||||
auto safeDict = load_safetensor(file_path);
|
||||
CHECK_EQ(safeDict.size(), 2);
|
||||
|
Loading…
Reference in New Issue
Block a user