its working

This commit is contained in:
dc-dc-dc 2023-12-18 14:33:23 -05:00
parent 9be3ea69ee
commit 42baa095d1
5 changed files with 29 additions and 4 deletions

View File

@ -31,10 +31,14 @@ void swap_endianess(uint8_t* data_bytes, size_t N) {
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (len_ == 0) {
len_ = out.nbytes();
}
printf("Load::eval: offset= %ld len_ = %ld\n", offset_, len_);
out.set_data(allocator::malloc_or_wait(len_));
reader_->seek(offset_, std::ios_base::beg);
reader_->read(out.data<char>(), out.nbytes());
reader_->read(out.data<char>(), len_);
if (swap_endianness_) {
switch (out.itemsize()) {

View File

@ -225,7 +225,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
auto loaded_array = array(
shape,
dtype,
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
std::make_unique<Load>(
to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness),
std::vector<array>{});
if (col_contiguous) {
loaded_array = transpose(loaded_array, s);

View File

@ -826,6 +826,7 @@ class Load : public Primitive {
Stream stream,
std::shared_ptr<io::Reader> reader,
size_t offset,
size_t len,
bool swap_endianness = false)
: Primitive(stream),
reader_(reader),
@ -841,6 +842,7 @@ class Load : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
std::shared_ptr<io::Reader> reader_;
size_t offset_;
size_t len_;
bool swap_endianness_;
};

View File

@ -189,6 +189,7 @@ std::unordered_map<std::string, array> load_safetensor(
throw std::runtime_error(
"[load_safetensor] Invalid json metadata " + in_stream->label());
}
size_t offset = jsonHeaderLength + 8;
// Parse the json raw data
std::unordered_map<std::string, array> res;
for (auto& [key, obj] : *metadata.getObject()) {
@ -204,9 +205,22 @@ std::unordered_map<std::string, array> load_safetensor(
data_offsets_vec.push_back(offset->getNumber());
}
if (dtype == "F32") {
res.insert({key, zeros(shape_vec, s)});
auto loaded_array = array(
shape_vec,
float32,
std::make_unique<Load>(
to_stream(s),
in_stream,
offset + data_offsets->at(0)->getNumber(),
offset + data_offsets->at(1)->getNumber(),
false),
std::vector<array>{});
res.insert({key, loaded_array});
}
}
// for (auto& [key, arr] : res) {
// arr.eval();
// }
return res;
}

View File

@ -145,6 +145,10 @@ TEST_CASE("test load_safetensor") {
array test = safeDict.at("test");
CHECK_EQ(test.dtype(), float32);
CHECK_EQ(test.shape(), std::vector<int>({4}));
array b = array({1.0, 2.0, 3.0, 4.0});
MESSAGE("test: " << test);
MESSAGE("b: " << b);
CHECK(array_equal(test, b).item<bool>());
}
TEST_CASE("test single array serialization") {