load changes were not needed after all

This commit is contained in:
dc-dc-dc 2023-12-18 17:11:25 -05:00
parent 60132a16de
commit f4d876d35f
5 changed files with 3 additions and 11 deletions

View File

@ -31,13 +31,10 @@ void swap_endianess(uint8_t* data_bytes, size_t N) {
void Load::eval(const std::vector<array>& inputs, array& out) { void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0); assert(inputs.size() == 0);
if (len_ == 0) { out.set_data(allocator::malloc_or_wait(out.nbytes()));
len_ = out.nbytes();
}
out.set_data(allocator::malloc_or_wait(len_));
reader_->seek(offset_, std::ios_base::beg); reader_->seek(offset_, std::ios_base::beg);
reader_->read(out.data<char>(), len_); reader_->read(out.data<char>(), out.nbytes());
if (swap_endianness_) { if (swap_endianness_) {
switch (out.itemsize()) { switch (out.itemsize()) {

View File

@ -225,8 +225,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
auto loaded_array = array( auto loaded_array = array(
shape, shape,
dtype, dtype,
std::make_unique<Load>( std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness),
std::vector<array>{}); std::vector<array>{});
if (col_contiguous) { if (col_contiguous) {
loaded_array = transpose(loaded_array, s); loaded_array = transpose(loaded_array, s);

View File

@ -826,7 +826,6 @@ class Load : public Primitive {
Stream stream, Stream stream,
std::shared_ptr<io::Reader> reader, std::shared_ptr<io::Reader> reader,
size_t offset, size_t offset,
size_t len,
bool swap_endianness = false) bool swap_endianness = false)
: Primitive(stream), : Primitive(stream),
reader_(reader), reader_(reader),
@ -842,7 +841,6 @@ class Load : public Primitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
std::shared_ptr<io::Reader> reader_; std::shared_ptr<io::Reader> reader_;
size_t offset_; size_t offset_;
size_t len_;
bool swap_endianness_; bool swap_endianness_;
}; };

View File

@ -307,7 +307,6 @@ std::unordered_map<std::string, array> load_safetensor(
to_stream(s), to_stream(s),
in_stream, in_stream,
offset + data_offsets->at(0)->getNumber(), offset + data_offsets->at(0)->getNumber(),
offset + data_offsets->at(1)->getNumber(),
false), false),
std::vector<array>{}); std::vector<array>{});
res.insert({key, loaded_array}); res.insert({key, loaded_array});

View File

@ -165,7 +165,6 @@ TEST_CASE("test save_safetensor") {
auto map = std::unordered_map<std::string, array>(); auto map = std::unordered_map<std::string, array>();
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})}); map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
map.insert({"test2", ones({2, 2})}); map.insert({"test2", ones({2, 2})});
MESSAGE("SAVING");
save_safetensor(file_path, map); save_safetensor(file_path, map);
auto safeDict = load_safetensor(file_path); auto safeDict = load_safetensor(file_path);
CHECK_EQ(safeDict.size(), 2); CHECK_EQ(safeDict.size(), 2);