mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
GGUF: Load and save metadata (#446)
* gguf metadata --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
6589c869d6
commit
ddf50113c5
55
mlx/io.h
Normal file
55
mlx/io.h
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file, array a);
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
|
||||
/** Load array from file in .npy format */
|
||||
array load(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>);
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>);
|
||||
|
||||
using MetaData =
|
||||
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
|
||||
|
||||
/** Load array map and metadata from .gguf file format */
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
load_gguf(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array> array_map,
|
||||
std::unordered_map<std::string, MetaData> meta_data = {});
|
||||
|
||||
} // namespace mlx::core
|
309
mlx/io/gguf.cpp
309
mlx/io/gguf.cpp
@ -1,9 +1,12 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/io.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
extern "C" {
|
||||
@ -12,6 +15,9 @@ extern "C" {
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// https://github.com/antirez/gguf-tools/blob/af7d88d808a7608a33723fba067036202910acb3/gguflib.h#L102-L108
|
||||
constexpr int gguf_array_header_size = 12;
|
||||
|
||||
std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
|
||||
switch (dtype) {
|
||||
case float32:
|
||||
@ -46,7 +52,7 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||
std::pair<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||
std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
|
||||
// If there's an equivalent type, we can simply copy.
|
||||
if (equivalent_dtype.has_value()) {
|
||||
@ -70,15 +76,132 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||
return {buffer, float16};
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> load_gguf(
|
||||
const std::string& file,
|
||||
StreamOrDevice s) {
|
||||
std::unordered_map<std::string, array> result;
|
||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
void set_mx_value_from_gguf(
|
||||
gguf_ctx* ctx,
|
||||
uint32_t type,
|
||||
gguf_value* val,
|
||||
MetaData& value) {
|
||||
switch (type) {
|
||||
case GGUF_VALUE_TYPE_UINT8:
|
||||
value = array(val->uint8, uint8);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT8:
|
||||
value = array(val->int8, int8);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT16:
|
||||
value = array(val->uint16, uint16);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT16:
|
||||
value = array(val->int16, int16);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT32:
|
||||
value = array(val->uint32, uint32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT32:
|
||||
value = array(val->int32, int32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT64:
|
||||
value = array(val->uint64, uint64);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT64:
|
||||
value = array(val->int64, int64);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_FLOAT32:
|
||||
value = array(val->float32, float32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_BOOL:
|
||||
value = array(val->boolval, bool_);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_STRING:
|
||||
value =
|
||||
std::string(val->string.string, static_cast<int>(val->string.len));
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_FLOAT64:
|
||||
value = array(val->float64, float32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_ARRAY: {
|
||||
ctx->off += gguf_array_header_size; // Skip header
|
||||
char* data = reinterpret_cast<char*>(val) + gguf_array_header_size;
|
||||
auto size = static_cast<int>(val->array.len);
|
||||
if (val->array.type == GGUF_VALUE_TYPE_ARRAY) {
|
||||
throw std::invalid_argument(
|
||||
"[load_gguf] Only supports loading 1-layer of nested arrays.");
|
||||
}
|
||||
switch (val->array.type) {
|
||||
case GGUF_VALUE_TYPE_UINT8:
|
||||
value = array(reinterpret_cast<uint8_t*>(data), {size}, uint8);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT8:
|
||||
value = array(reinterpret_cast<int8_t*>(data), {size}, int8);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT16:
|
||||
value = array(reinterpret_cast<uint16_t*>(data), {size}, uint16);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT16:
|
||||
value = array(reinterpret_cast<int16_t*>(data), {size}, int16);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT32:
|
||||
value = array(reinterpret_cast<uint32_t*>(data), {size}, uint32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT32:
|
||||
value = array(reinterpret_cast<int32_t*>(data), {size}, int32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_UINT64:
|
||||
value = array(reinterpret_cast<uint64_t*>(data), {size}, uint64);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_INT64:
|
||||
value = array(reinterpret_cast<uint64_t*>(data), {size}, int64);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_FLOAT32:
|
||||
value = array(reinterpret_cast<float*>(data), {size}, float32);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_BOOL:
|
||||
value = array(reinterpret_cast<bool*>(data), {size}, bool_);
|
||||
break;
|
||||
case GGUF_VALUE_TYPE_STRING: {
|
||||
std::vector<std::string> strs(size);
|
||||
for (auto& str : strs) {
|
||||
auto str_val = reinterpret_cast<gguf_string*>(data);
|
||||
data += (str_val->len + sizeof(gguf_string));
|
||||
str = std::string(str_val->string, static_cast<int>(str_val->len));
|
||||
ctx->off += (str_val->len + sizeof(gguf_string));
|
||||
}
|
||||
value = std::move(strs);
|
||||
break;
|
||||
}
|
||||
case GGUF_VALUE_TYPE_FLOAT64:
|
||||
value = array(reinterpret_cast<double*>(data), {size}, float32);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[load_gguf] Multiple levels of nested arrays are not supported.");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("[load_gguf] Received unexpected type.");
|
||||
break;
|
||||
}
|
||||
gguf_skip_key_values_section(ctx);
|
||||
if (type == GGUF_VALUE_TYPE_STRING) {
|
||||
ctx->off += (sizeof(gguf_string) + std::get<std::string>(value).size());
|
||||
} else if (auto pv = std::get_if<array>(&value); pv) {
|
||||
ctx->off += pv->nbytes();
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
|
||||
std::unordered_map<std::string, MetaData> metadata;
|
||||
gguf_key key;
|
||||
while (gguf_get_key(ctx, &key)) {
|
||||
std::string key_name = std::string(key.name, key.namelen);
|
||||
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
|
||||
set_mx_value_from_gguf(ctx, key.type, key.val, val);
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
|
||||
std::unordered_map<std::string, array> array_map;
|
||||
gguf_tensor tensor;
|
||||
while (gguf_get_tensor(ctx, &tensor)) {
|
||||
std::vector<int> shape;
|
||||
@ -89,27 +212,181 @@ std::unordered_map<std::string, array> load_gguf(
|
||||
const auto& [data, dtype] = extract_tensor_data(&tensor);
|
||||
array loaded_array = array(data, shape, dtype);
|
||||
std::string name = std::string(tensor.name, tensor.namelen);
|
||||
result.insert({name, loaded_array});
|
||||
array_map.insert({name, loaded_array});
|
||||
}
|
||||
gguf_close(ctx);
|
||||
return result;
|
||||
return array_map;
|
||||
}
|
||||
|
||||
void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
load_gguf(const std::string& file, StreamOrDevice s) {
|
||||
gguf_ctx* ctx = gguf_open(file.c_str());
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[load_gguf] gguf_init failed");
|
||||
}
|
||||
auto metadata = load_metadata(ctx);
|
||||
auto arrays = load_arrays(ctx);
|
||||
gguf_close(ctx);
|
||||
return {arrays, metadata};
|
||||
}
|
||||
|
||||
void append_kv_array(
|
||||
gguf_ctx* ctx,
|
||||
const std::string& key,
|
||||
array& val,
|
||||
uint32_t gguf_type) {
|
||||
if (val.ndim() == 1) {
|
||||
size_t gguf_size = val.nbytes() + gguf_array_header_size;
|
||||
std::vector<char> val_vec(gguf_size);
|
||||
gguf_value* gguf_val = reinterpret_cast<gguf_value*>(val_vec.data());
|
||||
gguf_val->array.type = gguf_type;
|
||||
gguf_val->array.len = val.size();
|
||||
memcpy(
|
||||
val_vec.data() + gguf_array_header_size,
|
||||
val.data<char>(),
|
||||
val.nbytes());
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
GGUF_VALUE_TYPE_ARRAY,
|
||||
reinterpret_cast<void*>(val_vec.data()),
|
||||
gguf_size);
|
||||
} else {
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
gguf_type,
|
||||
reinterpret_cast<void*>(val.data<char>()),
|
||||
val.nbytes());
|
||||
}
|
||||
}
|
||||
|
||||
void save_gguf(
|
||||
std::string file,
|
||||
std::unordered_map<std::string, array> array_map,
|
||||
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
|
||||
// Add .gguf to file name if it is not there
|
||||
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
|
||||
file += ".gguf";
|
||||
}
|
||||
|
||||
gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error("[save_gguf] gguf_create failed");
|
||||
}
|
||||
|
||||
auto string_to_gguf = [](char* dst, const std::string& src) {
|
||||
gguf_string* val = reinterpret_cast<gguf_string*>(dst);
|
||||
val->len = src.length();
|
||||
memcpy(val->string, src.c_str(), src.length());
|
||||
};
|
||||
|
||||
// Save any meta data
|
||||
for (auto& [key, value] : metadata) {
|
||||
if (auto pv = std::get_if<std::string>(&value); pv) {
|
||||
const std::string& str = *pv;
|
||||
size_t size = sizeof(gguf_string) + str.length();
|
||||
std::vector<char> val_vec(size);
|
||||
string_to_gguf(val_vec.data(), str);
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
GGUF_VALUE_TYPE_STRING,
|
||||
static_cast<void*>(val_vec.data()),
|
||||
size);
|
||||
} else if (auto pv = std::get_if<std::vector<std::string>>(&value); pv) {
|
||||
const auto& str_vec = *pv;
|
||||
auto mem_size = std::accumulate(
|
||||
str_vec.begin(), str_vec.end(), 0, [](size_t accum, const auto& s) {
|
||||
return accum + s.size();
|
||||
});
|
||||
mem_size += str_vec.size() * sizeof(gguf_string) + gguf_array_header_size;
|
||||
std::vector<char> val_vec(mem_size);
|
||||
gguf_value* val = reinterpret_cast<gguf_value*>(val_vec.data());
|
||||
val->array.type = GGUF_VALUE_TYPE_STRING;
|
||||
val->array.len = str_vec.size();
|
||||
auto str_ptr = val_vec.data() + gguf_array_header_size;
|
||||
for (auto& str : str_vec) {
|
||||
string_to_gguf(str_ptr, str);
|
||||
str_ptr += str.length() + sizeof(gguf_string);
|
||||
}
|
||||
gguf_append_kv(
|
||||
ctx,
|
||||
key.c_str(),
|
||||
key.length(),
|
||||
GGUF_VALUE_TYPE_ARRAY,
|
||||
static_cast<void*>(val),
|
||||
mem_size);
|
||||
} else if (auto pv = std::get_if<array>(&value); pv) {
|
||||
array v = *pv;
|
||||
if (v.ndim() > 1) {
|
||||
throw std::runtime_error(
|
||||
"[save_gguf] Cannot save arrays with more than one dimension.");
|
||||
}
|
||||
if (v.size() == 0) {
|
||||
throw std::runtime_error("[save_gguf] Cannot save empty arrays.");
|
||||
}
|
||||
|
||||
eval(v);
|
||||
if (!v.flags().row_contiguous) {
|
||||
v = reshape(flatten(v), v.shape());
|
||||
}
|
||||
if (!v.flags().row_contiguous) {
|
||||
throw std::runtime_error(
|
||||
"[save_gguf] Cannot save non contiguous arrays.");
|
||||
}
|
||||
switch (v.dtype()) {
|
||||
case float32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
|
||||
break;
|
||||
case int64:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
|
||||
break;
|
||||
case int32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
|
||||
break;
|
||||
case int16:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
|
||||
break;
|
||||
case int8:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
|
||||
break;
|
||||
case uint64:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
|
||||
break;
|
||||
case uint32:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
|
||||
break;
|
||||
case uint16:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
|
||||
break;
|
||||
case uint8:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
|
||||
break;
|
||||
case bool_:
|
||||
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
msg << "[save_gguf] array type " << v.dtype()
|
||||
<< " not support for metadata.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[save_gguf] Received unexpected type in metadata");
|
||||
}
|
||||
}
|
||||
|
||||
// Tensor offsets are relative to data section, so we start at offset 0.
|
||||
uint64_t tensor_offset = 0;
|
||||
|
||||
// First, append the tensor info
|
||||
for (auto& [key, arr] : a) {
|
||||
for (auto& [key, arr] : array_map) {
|
||||
arr.eval();
|
||||
|
||||
// Try to make it row contiguous
|
||||
@ -154,7 +431,7 @@ void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
|
||||
}
|
||||
|
||||
// Then, append the tensor weights
|
||||
for (const auto& [key, arr] : a) {
|
||||
for (const auto& [key, arr] : array_map) {
|
||||
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
|
||||
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
|
||||
}
|
||||
|
@ -3,8 +3,8 @@
|
||||
#include <json.hpp>
|
||||
#include <stack>
|
||||
|
||||
#include "mlx/io.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/io.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
46
mlx/ops.h
46
mlx/ops.h
@ -1,14 +1,13 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "io/load.h"
|
||||
#include "stream.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -1040,20 +1039,6 @@ array conv2d(
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Serialization operations */
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
void save(std::shared_ptr<io::Writer> out_stream, array a);
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
void save(const std::string& file, array a);
|
||||
|
||||
/** Load array from reader in .npy format */
|
||||
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
||||
|
||||
/** Load array from file in .npy format */
|
||||
array load(const std::string& file, StreamOrDevice s = {});
|
||||
|
||||
/** Quantized matmul multiplies x with a quantized matrix w*/
|
||||
array quantized_matmul(
|
||||
const array& x,
|
||||
@ -1100,28 +1085,6 @@ array outer(const array& a, const array& b, StreamOrDevice s = {});
|
||||
/** Compute the inner product of two vectors. */
|
||||
array inner(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
StreamOrDevice s = {});
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
void save_safetensors(
|
||||
std::shared_ptr<io::Writer> in_stream,
|
||||
std::unordered_map<std::string, array>);
|
||||
void save_safetensors(
|
||||
const std::string& file,
|
||||
std::unordered_map<std::string, array>);
|
||||
|
||||
/** Load array map from .gguf file format */
|
||||
std::unordered_map<std::string, array> load_gguf(
|
||||
const std::string& file,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
void save_gguf(std::string file, std::unordered_map<std::string, array> a);
|
||||
|
||||
/** Compute D = beta * C + alpha * (A @ B) */
|
||||
array addmm(
|
||||
array c,
|
||||
@ -1130,4 +1093,5 @@ array addmm(
|
||||
const float& alpha = 1.f,
|
||||
const float& beta = 1.f,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -181,9 +181,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_gguf_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s) {
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
}
|
||||
@ -246,9 +247,10 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
"[load_npy] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
@ -268,6 +270,10 @@ DictOrArray mlx_load_helper(
|
||||
format.emplace(fname.substr(ext + 1));
|
||||
}
|
||||
|
||||
if (return_metadata && format.value() != "gguf") {
|
||||
throw std::invalid_argument(
|
||||
"[load] metadata not supported for format " + format.value());
|
||||
}
|
||||
if (format.value() == "safetensors") {
|
||||
return mlx_load_safetensor_helper(file, s);
|
||||
} else if (format.value() == "npz") {
|
||||
@ -275,7 +281,12 @@ DictOrArray mlx_load_helper(
|
||||
} else if (format.value() == "npy") {
|
||||
return mlx_load_npy_helper(file, s);
|
||||
} else if (format.value() == "gguf") {
|
||||
return mlx_load_gguf_helper(file, s);
|
||||
auto [weights, metadata] = mlx_load_gguf_helper(file, s);
|
||||
if (return_metadata) {
|
||||
return std::make_pair(weights, metadata);
|
||||
} else {
|
||||
return weights;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("[load] Unknown file format " + format.value());
|
||||
}
|
||||
@ -448,10 +459,19 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
"[save_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
void mlx_save_gguf_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict a,
|
||||
std::optional<py::dict> m) {
|
||||
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
} else {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -7,26 +7,36 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/io.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlx::core;
|
||||
|
||||
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||
using LoadOutputTypes = std::variant<
|
||||
array,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>>;
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_gguf_helper(
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_gguf_helper(py::object file, py::dict d);
|
||||
py::dict d,
|
||||
std::optional<py::dict> m);
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(py::object file, array a);
|
||||
void mlx_savez_helper(
|
||||
|
@ -1867,11 +1867,11 @@ void init_ops(py::module_& m) {
|
||||
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are positive infinity.
|
||||
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are positive infinity.
|
||||
)pbdoc");
|
||||
@ -1886,11 +1886,11 @@ void init_ops(py::module_& m) {
|
||||
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are negative infinity.
|
||||
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are negative infinity.
|
||||
)pbdoc");
|
||||
@ -3117,10 +3117,11 @@ void init_ops(py::module_& m) {
|
||||
"file"_a,
|
||||
py::pos_only(),
|
||||
"format"_a = none,
|
||||
"return_metadata"_a = false,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||
load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||
|
||||
Load array(s) from a binary file.
|
||||
|
||||
@ -3131,10 +3132,15 @@ void init_ops(py::module_& m) {
|
||||
format (str, optional): Format of the file. If ``None``, the format
|
||||
is inferred from the file extension. Supported formats: ``npy``,
|
||||
``npz``, and ``safetensors``. Default: ``None``.
|
||||
return_metadata (bool, optional): Load the metadata for formats which
|
||||
support matadata. The metadata will be returned as an additional
|
||||
dictionary.
|
||||
Returns:
|
||||
result (array, dict):
|
||||
A single array if loading from a ``.npy`` file or a dict mapping
|
||||
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
|
||||
If ``return_metadata` is ``True`` an additional dictionary of metadata
|
||||
will be returned.
|
||||
|
||||
Warning:
|
||||
|
||||
@ -3164,8 +3170,9 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_gguf_helper,
|
||||
"file"_a,
|
||||
"arrays"_a,
|
||||
"metadata"_a = none,
|
||||
R"pbdoc(
|
||||
save_gguf(file: str, arrays: Dict[str, array])
|
||||
save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])
|
||||
|
||||
Save array(s) to a binary file in ``.gguf`` format.
|
||||
|
||||
@ -3175,6 +3182,9 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary of
|
||||
metadata to be saved. The values can be a scalar or 1D obj:`array`,
|
||||
a :obj:`str`, or a :obj:`list` of :obj:`str`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"where",
|
||||
@ -3499,7 +3509,7 @@ void init_ops(py::module_& m) {
|
||||
c (array): Input array or scalar.
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
alpha (float, optional): Scaling factor for the
|
||||
alpha (float, optional): Scaling factor for the
|
||||
matrix product of ``a`` and ``b`` (default: ``1``)
|
||||
beta (float, optional): Scaling factor for ``c`` (default: ``1``)
|
||||
|
||||
|
@ -576,8 +576,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
||||
|
||||
for r, t in zip(dout_ref, dout_test):
|
||||
self.assertListEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-5).item())
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
@ -117,6 +117,115 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||
)
|
||||
|
||||
def test_save_and_load_gguf_metadata_basic(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
metadata = {}
|
||||
|
||||
# Empty works
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
# Loads without the metadata
|
||||
load_dict = mx.load(save_file_mlx)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
|
||||
# Loads empty metadata
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 0)
|
||||
|
||||
# Loads string metadata
|
||||
metadata = {"meta": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta"], "data")
|
||||
|
||||
def test_save_and_load_gguf_metadata_arrays(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test scalars and one dimensional arrays
|
||||
for t in [
|
||||
mx.uint8,
|
||||
mx.int8,
|
||||
mx.uint16,
|
||||
mx.int16,
|
||||
mx.uint32,
|
||||
mx.int32,
|
||||
mx.uint64,
|
||||
mx.int64,
|
||||
mx.float32,
|
||||
]:
|
||||
for shape in [(), (2,)]:
|
||||
arr = mx.random.uniform(shape=shape).astype(t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
|
||||
self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.complex64]:
|
||||
with self.assertRaises(ValueError):
|
||||
arr = mx.array(1, t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
def test_save_and_load_gguf_metadata_mixed(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test string and array
|
||||
arr = mx.array(1.5)
|
||||
metadata = {"meta1": arr, "meta2": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 2)
|
||||
self.assertTrue("meta1" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
|
||||
self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
|
||||
self.assertTrue("meta2" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta2"], "data")
|
||||
|
||||
# Test list of strings
|
||||
metadata = {"meta": ["data1", "data2", "data345"]}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertEqual(meta_load_dict["meta"], metadata["meta"])
|
||||
|
||||
# Test a combination of stuff
|
||||
metadata = {
|
||||
"meta1": ["data1", "data2", "data345"],
|
||||
"meta2": mx.array([1, 2, 3, 4]),
|
||||
"meta3": "data",
|
||||
"meta4": mx.array(1.5),
|
||||
}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 4)
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, mx.array):
|
||||
self.assertTrue(mx.array_equal(meta_load_dict[k], v))
|
||||
else:
|
||||
self.assertEqual(meta_load_dict[k], v)
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
@ -37,33 +37,161 @@ TEST_CASE("test save_safetensors") {
|
||||
TEST_CASE("test gguf") {
|
||||
std::string file_path = get_temp_file("test_arr.gguf");
|
||||
using dict = std::unordered_map<std::string, array>;
|
||||
dict map = {
|
||||
dict original_weights = {
|
||||
{"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
|
||||
{"test2", reshape(arange(6), {3, 2})}};
|
||||
|
||||
save_gguf(file_path, map);
|
||||
auto loaded = load_gguf(file_path);
|
||||
CHECK_EQ(loaded.size(), 2);
|
||||
CHECK_EQ(loaded.count("test"), 1);
|
||||
CHECK_EQ(loaded.count("test2"), 1);
|
||||
for (auto [k, v] : loaded) {
|
||||
CHECK(array_equal(v, map.at(k)).item<bool>());
|
||||
{
|
||||
// Check saving loading just arrays, no metadata
|
||||
save_gguf(file_path, original_weights);
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 0);
|
||||
CHECK_EQ(loaded_weights.size(), 2);
|
||||
CHECK_EQ(loaded_weights.count("test"), 1);
|
||||
CHECK_EQ(loaded_weights.count("test2"), 1);
|
||||
for (auto [k, v] : loaded_weights) {
|
||||
CHECK(array_equal(v, original_weights.at(k)).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
// Test saving and loading string metadata
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
original_metadata.insert({"test_str", "my string"});
|
||||
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 1);
|
||||
CHECK_EQ(loaded_metadata.count("test_str"), 1);
|
||||
CHECK_EQ(std::get<std::string>(loaded_metadata.at("test_str")), "my string");
|
||||
|
||||
CHECK_EQ(loaded_weights.size(), 2);
|
||||
CHECK_EQ(loaded_weights.count("test"), 1);
|
||||
CHECK_EQ(loaded_weights.count("test2"), 1);
|
||||
for (auto [k, v] : loaded_weights) {
|
||||
CHECK(array_equal(v, original_weights.at(k)).item<bool>());
|
||||
}
|
||||
|
||||
std::vector<Dtype> unsupported_types = {
|
||||
bool_, uint8, uint32, uint64, int64, bfloat16, complex64};
|
||||
for (auto t : unsupported_types) {
|
||||
dict to_save = {{"test", astype(arange(5), t)}};
|
||||
CHECK_THROWS(save_gguf(file_path, to_save));
|
||||
CHECK_THROWS(save_gguf(file_path, to_save, original_metadata));
|
||||
}
|
||||
|
||||
std::vector<Dtype> supported_types = {int8, int32, float16};
|
||||
std::vector<Dtype> supported_types = {int8, int32, float16, float32};
|
||||
for (auto t : supported_types) {
|
||||
auto arr = astype(arange(5), t);
|
||||
dict to_save = {{"test", arr}};
|
||||
save_gguf(file_path, to_save);
|
||||
auto loaded = load_gguf(file_path);
|
||||
CHECK(array_equal(loaded.at("test"), arr).item<bool>());
|
||||
save_gguf(file_path, to_save, original_metadata);
|
||||
const auto& [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK(array_equal(loaded_weights.at("test"), arr).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test gguf metadata") {
|
||||
std::string file_path = get_temp_file("test_arr.gguf");
|
||||
using dict = std::unordered_map<std::string, array>;
|
||||
dict original_weights = {
|
||||
{"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
|
||||
{"test2", reshape(arange(6), {3, 2})}};
|
||||
|
||||
// Scalar array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array(1.0)});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 1);
|
||||
CHECK_EQ(loaded_metadata.count("test_arr"), 1);
|
||||
|
||||
auto arr = std::get<array>(loaded_metadata.at("test_arr"));
|
||||
CHECK_EQ(arr.item<float>(), 1.0f);
|
||||
}
|
||||
|
||||
// 1D Array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
auto arr = array({1.0, 2.0});
|
||||
original_metadata.insert({"test_arr", arr});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 1);
|
||||
CHECK_EQ(loaded_metadata.count("test_arr"), 1);
|
||||
|
||||
auto loaded_arr = std::get<array>(loaded_metadata.at("test_arr"));
|
||||
CHECK(array_equal(arr, loaded_arr).item<bool>());
|
||||
|
||||
// Preserves dims
|
||||
arr = array({1.0});
|
||||
original_metadata["test_arr"] = arr;
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
std::tie(loaded_weights, loaded_metadata) = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 1);
|
||||
CHECK_EQ(loaded_metadata.count("test_arr"), 1);
|
||||
|
||||
loaded_arr = std::get<array>(loaded_metadata.at("test_arr"));
|
||||
CHECK(array_equal(arr, loaded_arr).item<bool>());
|
||||
}
|
||||
|
||||
// > 1D array throws
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
|
||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||
}
|
||||
|
||||
// empty array throws
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
original_metadata.insert({"test_arr", array({})});
|
||||
CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
|
||||
}
|
||||
|
||||
// vector of string
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||
original_metadata.insert({"meta", data});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 1);
|
||||
CHECK_EQ(loaded_metadata.count("meta"), 1);
|
||||
auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta"]);
|
||||
CHECK_EQ(strs.size(), 3);
|
||||
for (int i = 0; i < strs.size(); ++i) {
|
||||
CHECK_EQ(strs[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// vector of string, string, scalar, and array
|
||||
{
|
||||
std::unordered_map<std::string, MetaData> original_metadata;
|
||||
std::vector<std::string> data = {"data1", "data2", "data1234"};
|
||||
original_metadata.insert({"meta1", data});
|
||||
original_metadata.insert({"meta2", array(2.5)});
|
||||
original_metadata.insert({"meta3", array({1, 2, 3})});
|
||||
original_metadata.insert({"meta4", "last"});
|
||||
save_gguf(file_path, original_weights, original_metadata);
|
||||
|
||||
auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
|
||||
CHECK_EQ(loaded_metadata.size(), 4);
|
||||
auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta1"]);
|
||||
CHECK_EQ(strs.size(), 3);
|
||||
for (int i = 0; i < strs.size(); ++i) {
|
||||
CHECK_EQ(strs[i], data[i]);
|
||||
}
|
||||
auto& arr = std::get<array>(loaded_metadata["meta2"]);
|
||||
CHECK_EQ(arr.item<float>(), 2.5);
|
||||
|
||||
arr = std::get<array>(loaded_metadata["meta3"]);
|
||||
CHECK(array_equal(arr, array({1, 2, 3})).item<bool>());
|
||||
|
||||
auto& str = std::get<std::string>(loaded_metadata["meta4"]);
|
||||
CHECK_EQ(str, "last");
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user