mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
its working
This commit is contained in:
parent
9be3ea69ee
commit
42baa095d1
@ -31,10 +31,14 @@ void swap_endianess(uint8_t* data_bytes, size_t N) {
|
|||||||
|
|
||||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 0);
|
assert(inputs.size() == 0);
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
if (len_ == 0) {
|
||||||
|
len_ = out.nbytes();
|
||||||
|
}
|
||||||
|
printf("Load::eval: offset= %ld len_ = %ld\n", offset_, len_);
|
||||||
|
out.set_data(allocator::malloc_or_wait(len_));
|
||||||
|
|
||||||
reader_->seek(offset_, std::ios_base::beg);
|
reader_->seek(offset_, std::ios_base::beg);
|
||||||
reader_->read(out.data<char>(), out.nbytes());
|
reader_->read(out.data<char>(), len_);
|
||||||
|
|
||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
switch (out.itemsize()) {
|
switch (out.itemsize()) {
|
||||||
|
@ -225,7 +225,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
auto loaded_array = array(
|
auto loaded_array = array(
|
||||||
shape,
|
shape,
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
std::make_unique<Load>(
|
||||||
|
to_stream(s), in_stream, offset, in_stream->tell(), swap_endianness),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
if (col_contiguous) {
|
if (col_contiguous) {
|
||||||
loaded_array = transpose(loaded_array, s);
|
loaded_array = transpose(loaded_array, s);
|
||||||
|
@ -826,6 +826,7 @@ class Load : public Primitive {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
std::shared_ptr<io::Reader> reader,
|
std::shared_ptr<io::Reader> reader,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
|
size_t len,
|
||||||
bool swap_endianness = false)
|
bool swap_endianness = false)
|
||||||
: Primitive(stream),
|
: Primitive(stream),
|
||||||
reader_(reader),
|
reader_(reader),
|
||||||
@ -841,6 +842,7 @@ class Load : public Primitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
std::shared_ptr<io::Reader> reader_;
|
std::shared_ptr<io::Reader> reader_;
|
||||||
size_t offset_;
|
size_t offset_;
|
||||||
|
size_t len_;
|
||||||
bool swap_endianness_;
|
bool swap_endianness_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -189,6 +189,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
||||||
}
|
}
|
||||||
|
size_t offset = jsonHeaderLength + 8;
|
||||||
// Parse the json raw data
|
// Parse the json raw data
|
||||||
std::unordered_map<std::string, array> res;
|
std::unordered_map<std::string, array> res;
|
||||||
for (auto& [key, obj] : *metadata.getObject()) {
|
for (auto& [key, obj] : *metadata.getObject()) {
|
||||||
@ -204,9 +205,22 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
data_offsets_vec.push_back(offset->getNumber());
|
data_offsets_vec.push_back(offset->getNumber());
|
||||||
}
|
}
|
||||||
if (dtype == "F32") {
|
if (dtype == "F32") {
|
||||||
res.insert({key, zeros(shape_vec, s)});
|
auto loaded_array = array(
|
||||||
|
shape_vec,
|
||||||
|
float32,
|
||||||
|
std::make_unique<Load>(
|
||||||
|
to_stream(s),
|
||||||
|
in_stream,
|
||||||
|
offset + data_offsets->at(0)->getNumber(),
|
||||||
|
offset + data_offsets->at(1)->getNumber(),
|
||||||
|
false),
|
||||||
|
std::vector<array>{});
|
||||||
|
res.insert({key, loaded_array});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// for (auto& [key, arr] : res) {
|
||||||
|
// arr.eval();
|
||||||
|
// }
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,6 +145,10 @@ TEST_CASE("test load_safetensor") {
|
|||||||
array test = safeDict.at("test");
|
array test = safeDict.at("test");
|
||||||
CHECK_EQ(test.dtype(), float32);
|
CHECK_EQ(test.dtype(), float32);
|
||||||
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
CHECK_EQ(test.shape(), std::vector<int>({4}));
|
||||||
|
array b = array({1.0, 2.0, 3.0, 4.0});
|
||||||
|
MESSAGE("test: " << test);
|
||||||
|
MESSAGE("b: " << b);
|
||||||
|
CHECK(array_equal(test, b).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test single array serialization") {
|
TEST_CASE("test single array serialization") {
|
||||||
|
Loading…
Reference in New Issue
Block a user