diff --git a/mlx/io.h b/mlx/io.h new file mode 100644 index 000000000..c58e1959e --- /dev/null +++ b/mlx/io.h @@ -0,0 +1,55 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/io/load.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core { + +/** Save array to out stream in .npy format */ +void save(std::shared_ptr out_stream, array a); + +/** Save array to file in .npy format */ +void save(const std::string& file, array a); + +/** Load array from reader in .npy format */ +array load(std::shared_ptr in_stream, StreamOrDevice s = {}); + +/** Load array from file in .npy format */ +array load(const std::string& file, StreamOrDevice s = {}); + +/** Load array map from .safetensors file format */ +std::unordered_map load_safetensors( + std::shared_ptr in_stream, + StreamOrDevice s = {}); +std::unordered_map load_safetensors( + const std::string& file, + StreamOrDevice s = {}); + +void save_safetensors( + std::shared_ptr in_stream, + std::unordered_map); +void save_safetensors( + const std::string& file, + std::unordered_map); + +using MetaData = + std::variant>; + +/** Load array map and metadata from .gguf file format */ +std::pair< + std::unordered_map, + std::unordered_map> +load_gguf(const std::string& file, StreamOrDevice s = {}); + +void save_gguf( + std::string file, + std::unordered_map array_map, + std::unordered_map meta_data = {}); + +} // namespace mlx::core diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 231572145..8f3c6871f 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -1,9 +1,12 @@ // Copyright © 2023 Apple Inc. 
+#include #include +#include -#include "mlx/ops.h" +#include "mlx/io.h" #include "mlx/primitives.h" +#include "mlx/transforms.h" #include "mlx/utils.h" extern "C" { @@ -12,6 +15,9 @@ extern "C" { namespace mlx::core { +// https://github.com/antirez/gguf-tools/blob/af7d88d808a7608a33723fba067036202910acb3/gguflib.h#L102-L108 +constexpr int gguf_array_header_size = 12; + std::optional dtype_to_gguf_tensor_type(const Dtype& dtype) { switch (dtype) { case float32: @@ -46,7 +52,7 @@ std::optional gguf_type_to_dtype(const uint32_t& gguf_type) { } } -std::tuple extract_tensor_data(gguf_tensor* tensor) { +std::pair extract_tensor_data(gguf_tensor* tensor) { std::optional equivalent_dtype = gguf_type_to_dtype(tensor->type); // If there's an equivalent type, we can simply copy. if (equivalent_dtype.has_value()) { @@ -70,15 +76,132 @@ std::tuple extract_tensor_data(gguf_tensor* tensor) { return {buffer, float16}; } -std::unordered_map load_gguf( - const std::string& file, - StreamOrDevice s) { - std::unordered_map result; - gguf_ctx* ctx = gguf_open(file.c_str()); - if (!ctx) { - throw std::runtime_error("[load_gguf] gguf_init failed"); +void set_mx_value_from_gguf( + gguf_ctx* ctx, + uint32_t type, + gguf_value* val, + MetaData& value) { + switch (type) { + case GGUF_VALUE_TYPE_UINT8: + value = array(val->uint8, uint8); + break; + case GGUF_VALUE_TYPE_INT8: + value = array(val->int8, int8); + break; + case GGUF_VALUE_TYPE_UINT16: + value = array(val->uint16, uint16); + break; + case GGUF_VALUE_TYPE_INT16: + value = array(val->int16, int16); + break; + case GGUF_VALUE_TYPE_UINT32: + value = array(val->uint32, uint32); + break; + case GGUF_VALUE_TYPE_INT32: + value = array(val->int32, int32); + break; + case GGUF_VALUE_TYPE_UINT64: + value = array(val->uint64, uint64); + break; + case GGUF_VALUE_TYPE_INT64: + value = array(val->int64, int64); + break; + case GGUF_VALUE_TYPE_FLOAT32: + value = array(val->float32, float32); + break; + case GGUF_VALUE_TYPE_BOOL: + value = 
array(val->boolval, bool_); + break; + case GGUF_VALUE_TYPE_STRING: + value = + std::string(val->string.string, static_cast(val->string.len)); + break; + case GGUF_VALUE_TYPE_FLOAT64: + value = array(val->float64, float32); + break; + case GGUF_VALUE_TYPE_ARRAY: { + ctx->off += gguf_array_header_size; // Skip header + char* data = reinterpret_cast(val) + gguf_array_header_size; + auto size = static_cast(val->array.len); + if (val->array.type == GGUF_VALUE_TYPE_ARRAY) { + throw std::invalid_argument( + "[load_gguf] Only supports loading 1-layer of nested arrays."); + } + switch (val->array.type) { + case GGUF_VALUE_TYPE_UINT8: + value = array(reinterpret_cast(data), {size}, uint8); + break; + case GGUF_VALUE_TYPE_INT8: + value = array(reinterpret_cast(data), {size}, int8); + break; + case GGUF_VALUE_TYPE_UINT16: + value = array(reinterpret_cast(data), {size}, uint16); + break; + case GGUF_VALUE_TYPE_INT16: + value = array(reinterpret_cast(data), {size}, int16); + break; + case GGUF_VALUE_TYPE_UINT32: + value = array(reinterpret_cast(data), {size}, uint32); + break; + case GGUF_VALUE_TYPE_INT32: + value = array(reinterpret_cast(data), {size}, int32); + break; + case GGUF_VALUE_TYPE_UINT64: + value = array(reinterpret_cast(data), {size}, uint64); + break; + case GGUF_VALUE_TYPE_INT64: + value = array(reinterpret_cast(data), {size}, int64); + break; + case GGUF_VALUE_TYPE_FLOAT32: + value = array(reinterpret_cast(data), {size}, float32); + break; + case GGUF_VALUE_TYPE_BOOL: + value = array(reinterpret_cast(data), {size}, bool_); + break; + case GGUF_VALUE_TYPE_STRING: { + std::vector strs(size); + for (auto& str : strs) { + auto str_val = reinterpret_cast(data); + data += (str_val->len + sizeof(gguf_string)); + str = std::string(str_val->string, static_cast(str_val->len)); + ctx->off += (str_val->len + sizeof(gguf_string)); + } + value = std::move(strs); + break; + } + case GGUF_VALUE_TYPE_FLOAT64: + value = array(reinterpret_cast(data), {size}, float32); + break; + 
default: + throw std::runtime_error( + "[load_gguf] Multiple levels of nested arrays are not supported."); + } + break; + } + default: + throw std::runtime_error("[load_gguf] Received unexpected type."); + break; } - gguf_skip_key_values_section(ctx); + if (type == GGUF_VALUE_TYPE_STRING) { + ctx->off += (sizeof(gguf_string) + std::get(value).size()); + } else if (auto pv = std::get_if(&value); pv) { + ctx->off += pv->nbytes(); + } +} + +std::unordered_map load_metadata(gguf_ctx* ctx) { + std::unordered_map metadata; + gguf_key key; + while (gguf_get_key(ctx, &key)) { + std::string key_name = std::string(key.name, key.namelen); + auto& val = metadata.insert({key_name, MetaData{}}).first->second; + set_mx_value_from_gguf(ctx, key.type, key.val, val); + } + return metadata; +} + +std::unordered_map load_arrays(gguf_ctx* ctx) { + std::unordered_map array_map; gguf_tensor tensor; while (gguf_get_tensor(ctx, &tensor)) { std::vector shape; @@ -89,27 +212,181 @@ std::unordered_map load_gguf( const auto& [data, dtype] = extract_tensor_data(&tensor); array loaded_array = array(data, shape, dtype); std::string name = std::string(tensor.name, tensor.namelen); - result.insert({name, loaded_array}); + array_map.insert({name, loaded_array}); } - gguf_close(ctx); - return result; + return array_map; } -void save_gguf(std::string file, std::unordered_map a) { +std::pair< + std::unordered_map, + std::unordered_map> +load_gguf(const std::string& file, StreamOrDevice s) { + gguf_ctx* ctx = gguf_open(file.c_str()); + if (!ctx) { + throw std::runtime_error("[load_gguf] gguf_init failed"); + } + auto metadata = load_metadata(ctx); + auto arrays = load_arrays(ctx); + gguf_close(ctx); + return {arrays, metadata}; +} + +void append_kv_array( + gguf_ctx* ctx, + const std::string& key, + array& val, + uint32_t gguf_type) { + if (val.ndim() == 1) { + size_t gguf_size = val.nbytes() + gguf_array_header_size; + std::vector val_vec(gguf_size); + gguf_value* gguf_val = 
reinterpret_cast(val_vec.data()); + gguf_val->array.type = gguf_type; + gguf_val->array.len = val.size(); + memcpy( + val_vec.data() + gguf_array_header_size, + val.data(), + val.nbytes()); + gguf_append_kv( + ctx, + key.c_str(), + key.length(), + GGUF_VALUE_TYPE_ARRAY, + reinterpret_cast(val_vec.data()), + gguf_size); + } else { + gguf_append_kv( + ctx, + key.c_str(), + key.length(), + gguf_type, + reinterpret_cast(val.data()), + val.nbytes()); + } +} + +void save_gguf( + std::string file, + std::unordered_map array_map, + std::unordered_map metadata /* = {} */) { // Add .gguf to file name if it is not there if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") { file += ".gguf"; } + gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE); if (!ctx) { throw std::runtime_error("[save_gguf] gguf_create failed"); } + auto string_to_gguf = [](char* dst, const std::string& src) { + gguf_string* val = reinterpret_cast(dst); + val->len = src.length(); + memcpy(val->string, src.c_str(), src.length()); + }; + + // Save any meta data + for (auto& [key, value] : metadata) { + if (auto pv = std::get_if(&value); pv) { + const std::string& str = *pv; + size_t size = sizeof(gguf_string) + str.length(); + std::vector val_vec(size); + string_to_gguf(val_vec.data(), str); + gguf_append_kv( + ctx, + key.c_str(), + key.length(), + GGUF_VALUE_TYPE_STRING, + static_cast(val_vec.data()), + size); + } else if (auto pv = std::get_if>(&value); pv) { + const auto& str_vec = *pv; + auto mem_size = std::accumulate( + str_vec.begin(), str_vec.end(), 0, [](size_t accum, const auto& s) { + return accum + s.size(); + }); + mem_size += str_vec.size() * sizeof(gguf_string) + gguf_array_header_size; + std::vector val_vec(mem_size); + gguf_value* val = reinterpret_cast(val_vec.data()); + val->array.type = GGUF_VALUE_TYPE_STRING; + val->array.len = str_vec.size(); + auto str_ptr = val_vec.data() + gguf_array_header_size; + for (auto& str : str_vec) { + string_to_gguf(str_ptr, str); 
+ str_ptr += str.length() + sizeof(gguf_string); + } + gguf_append_kv( + ctx, + key.c_str(), + key.length(), + GGUF_VALUE_TYPE_ARRAY, + static_cast(val), + mem_size); + } else if (auto pv = std::get_if(&value); pv) { + array v = *pv; + if (v.ndim() > 1) { + throw std::runtime_error( + "[save_gguf] Cannot save arrays with more than one dimension."); + } + if (v.size() == 0) { + throw std::runtime_error("[save_gguf] Cannot save empty arrays."); + } + + eval(v); + if (!v.flags().row_contiguous) { + v = reshape(flatten(v), v.shape()); + } + if (!v.flags().row_contiguous) { + throw std::runtime_error( + "[save_gguf] Cannot save non contiguous arrays."); + } + switch (v.dtype()) { + case float32: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32); + break; + case int64: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64); + break; + case int32: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32); + break; + case int16: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16); + break; + case int8: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8); + break; + case uint64: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64); + break; + case uint32: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32); + break; + case uint16: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16); + break; + case uint8: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8); + break; + case bool_: + append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL); + break; + default: + std::ostringstream msg; + msg << "[save_gguf] array type " << v.dtype() + << " not support for metadata."; + throw std::invalid_argument(msg.str()); + } + } else { + throw std::runtime_error( + "[save_gguf] Received unexpected type in metadata"); + } + } + // Tensor offsets are relative to data section, so we start at offset 0. 
uint64_t tensor_offset = 0; // First, append the tensor info - for (auto& [key, arr] : a) { + for (auto& [key, arr] : array_map) { arr.eval(); // Try to make it row contiguous @@ -154,7 +431,7 @@ void save_gguf(std::string file, std::unordered_map a) { } // Then, append the tensor weights - for (const auto& [key, arr] : a) { + for (const auto& [key, arr] : array_map) { if (!gguf_append_tensor_data(ctx, (void*)arr.data(), arr.nbytes())) { throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed"); } diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index 406169312..7e7868d49 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -3,8 +3,8 @@ #include #include +#include "mlx/io.h" #include "mlx/io/load.h" -#include "mlx/ops.h" #include "mlx/primitives.h" using json = nlohmann::json; diff --git a/mlx/mlx.h b/mlx/mlx.h index 8d785c39f..a0a281974 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/device.h" #include "mlx/fft.h" +#include "mlx/io.h" #include "mlx/linalg.h" #include "mlx/ops.h" #include "mlx/random.h" diff --git a/mlx/ops.h b/mlx/ops.h index 865617cd9..1f2ace43f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1,14 +1,13 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#pragma once #include #include -#include "array.h" -#include "device.h" -#include "io/load.h" -#include "stream.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/stream.h" namespace mlx::core { @@ -1040,20 +1039,6 @@ array conv2d( int groups = 1, StreamOrDevice s = {}); -/** Serialization operations */ - -/** Save array to out stream in .npy format */ -void save(std::shared_ptr out_stream, array a); - -/** Save array to file in .npy format */ -void save(const std::string& file, array a); - -/** Load array from reader in .npy format */ -array load(std::shared_ptr in_stream, StreamOrDevice s = {}); - -/** Load array from file in .npy format */ -array load(const std::string& file, StreamOrDevice s = {}); - /** Quantized matmul multiplies x with a quantized matrix w*/ array quantized_matmul( const array& x, @@ -1100,28 +1085,6 @@ array outer(const array& a, const array& b, StreamOrDevice s = {}); /** Compute the inner product of two vectors. */ array inner(const array& a, const array& b, StreamOrDevice s = {}); -/** Load array map from .safetensors file format */ -std::unordered_map load_safetensors( - std::shared_ptr in_stream, - StreamOrDevice s = {}); -std::unordered_map load_safetensors( - const std::string& file, - StreamOrDevice s = {}); - -void save_safetensors( - std::shared_ptr in_stream, - std::unordered_map); -void save_safetensors( - const std::string& file, - std::unordered_map); - -/** Load array map from .gguf file format */ -std::unordered_map load_gguf( - const std::string& file, - StreamOrDevice s = {}); - -void save_gguf(std::string file, std::unordered_map a); - /** Compute D = beta * C + alpha * (A @ B) */ array addmm( array c, @@ -1130,4 +1093,5 @@ array addmm( const float& alpha = 1.f, const float& beta = 1.f, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/src/load.cpp b/python/src/load.cpp index 03b108d8e..92ad5e808 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -181,9 +181,10 @@ 
std::unordered_map mlx_load_safetensor_helper( "[load_safetensors] Input must be a file-like object, or string"); } -std::unordered_map mlx_load_gguf_helper( - py::object file, - StreamOrDevice s) { +std::pair< + std::unordered_map, + std::unordered_map> +mlx_load_gguf_helper(py::object file, StreamOrDevice s) { if (py::isinstance(file)) { // Assume .gguf file path string return load_gguf(py::cast(file), s); } @@ -246,9 +247,10 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) { "[load_npy] Input must be a file-like object, or string"); } -DictOrArray mlx_load_helper( +LoadOutputTypes mlx_load_helper( py::object file, std::optional format, + bool return_metadata, StreamOrDevice s) { if (!format.has_value()) { std::string fname; @@ -268,6 +270,10 @@ DictOrArray mlx_load_helper( format.emplace(fname.substr(ext + 1)); } + if (return_metadata && format.value() != "gguf") { + throw std::invalid_argument( + "[load] metadata not supported for format " + format.value()); + } if (format.value() == "safetensors") { return mlx_load_safetensor_helper(file, s); } else if (format.value() == "npz") { @@ -275,7 +281,12 @@ DictOrArray mlx_load_helper( } else if (format.value() == "npy") { return mlx_load_npy_helper(file, s); } else if (format.value() == "gguf") { - return mlx_load_gguf_helper(file, s); + auto [weights, metadata] = mlx_load_gguf_helper(file, s); + if (return_metadata) { + return std::make_pair(weights, metadata); + } else { + return weights; + } } else { throw std::invalid_argument("[load] Unknown file format " + format.value()); } @@ -448,10 +459,19 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) { "[save_safetensors] Input must be a file-like object, or string"); } -void mlx_save_gguf_helper(py::object file, py::dict d) { - auto arrays_map = d.cast>(); +void mlx_save_gguf_helper( + py::object file, + py::dict a, + std::optional m) { + auto arrays_map = a.cast>(); if (py::isinstance(file)) { - save_gguf(py::cast(file), arrays_map); + 
if (m) { + auto metadata_map = + m.value().cast>(); + save_gguf(py::cast(file), arrays_map, metadata_map); + } else { + save_gguf(py::cast(file), arrays_map); + } return; } diff --git a/python/src/load.h b/python/src/load.h index 19feb5f5a..dbe0f9cd6 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -7,26 +7,36 @@ #include #include #include -#include "mlx/ops.h" +#include "mlx/io.h" namespace py = pybind11; using namespace mlx::core; -using DictOrArray = std::variant>; +using LoadOutputTypes = std::variant< + array, + std::unordered_map, + std::pair< + std::unordered_map, + std::unordered_map>>; std::unordered_map mlx_load_safetensor_helper( py::object file, StreamOrDevice s); void mlx_save_safetensor_helper(py::object file, py::dict d); -std::unordered_map mlx_load_gguf_helper( +std::pair< + std::unordered_map, + std::unordered_map> +mlx_load_gguf_helper(py::object file, StreamOrDevice s); +void mlx_save_gguf_helper( py::object file, - StreamOrDevice s); -void mlx_save_gguf_helper(py::object file, py::dict d); + py::dict d, + std::optional m); -DictOrArray mlx_load_helper( +LoadOutputTypes mlx_load_helper( py::object file, std::optional format, + bool return_metadata, StreamOrDevice s); void mlx_save_helper(py::object file, array a); void mlx_savez_helper( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a85c94d86..879090e95 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1867,11 +1867,11 @@ void init_ops(py::module_& m) { isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array Return a boolean array indicating which elements are positive infinity. - + Args: a (array): Input array. stream (Union[None, Stream, Device]): Optional stream or device. - + Returns: array: The boolean array indicating which elements are positive infinity. 
)pbdoc"); @@ -1886,11 +1886,11 @@ void init_ops(py::module_& m) { isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array Return a boolean array indicating which elements are negative infinity. - + Args: a (array): Input array. stream (Union[None, Stream, Device]): Optional stream or device. - + Returns: array: The boolean array indicating which elements are negative infinity. )pbdoc"); @@ -3117,10 +3117,11 @@ void init_ops(py::module_& m) { "file"_a, py::pos_only(), "format"_a = none, + "return_metadata"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( - load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] + load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] Load array(s) from a binary file. @@ -3131,10 +3132,15 @@ void init_ops(py::module_& m) { format (str, optional): Format of the file. If ``None``, the format is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. Default: ``None``. + return_metadata (bool, optional): Load the metadata for formats which + support metadata. The metadata will be returned as an additional + dictionary. Returns: result (array, dict): A single array if loading from a ``.npy`` file or a dict mapping names to arrays if loading from a ``.npz`` or ``.safetensors`` file. + If ``return_metadata`` is ``True`` an additional dictionary of metadata + will be returned. Warning: @@ -3164,8 +3170,9 @@ void init_ops(py::module_& m) { &mlx_save_gguf_helper, "file"_a, "arrays"_a, + "metadata"_a = none, R"pbdoc( - save_gguf(file: str, arrays: Dict[str, array]) + save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]]) Save array(s) to a binary file in ``.gguf`` format. 
@@ -3175,6 +3182,9 @@ void init_ops(py::module_& m) { Args: file (file, str): File in which the array is saved. arrays (dict(str, array)): The dictionary of names to arrays to be saved. + metadata (dict(str, Union[array, str, list(str)])): The dictionary of + metadata to be saved. The values can be a scalar or 1D :obj:`array`, + a :obj:`str`, or a :obj:`list` of :obj:`str`. )pbdoc"); m.def( "where", @@ -3499,7 +3509,7 @@ void init_ops(py::module_& m) { c (array): Input array or scalar. a (array): Input array or scalar. b (array): Input array or scalar. - alpha (float, optional): Scaling factor for the + alpha (float, optional): Scaling factor for the matrix product of ``a`` and ``b`` (default: ``1``) beta (float, optional): Scaling factor for ``c`` (default: ``1``) diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index fe2346fea..d1ac0a3a1 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -576,8 +576,8 @@ class TestBlas(mlx_tests.MLXTestCase): ], ) - self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) + self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item()) for r, t in zip(dout_ref, dout_test): self.assertListEqual(r.shape, t.shape) - self.assertTrue(mx.allclose(r, t, atol=1e-5).item()) + self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 3b7baba54..a37ba83a9 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -117,6 +117,115 @@ class TestLoad(mlx_tests.MLXTestCase): mx.array_equal(load_dict["test"], save_dict["test"]) ) + def test_save_and_load_gguf_metadata_basic(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") + save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} + metadata = {} + + # Empty works + mx.save_gguf(save_file_mlx, save_dict, metadata) + + # Loads without the metadata + 
load_dict = mx.load(save_file_mlx) + self.assertTrue("test" in load_dict) + self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) + + # Loads empty metadata + load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertTrue("test" in load_dict) + self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) + self.assertEqual(len(meta_load_dict), 0) + + # Loads string metadata + metadata = {"meta": "data"} + mx.save_gguf(save_file_mlx, save_dict, metadata) + load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertTrue("test" in load_dict) + self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"])) + self.assertEqual(len(meta_load_dict), 1) + self.assertTrue("meta" in meta_load_dict) + self.assertEqual(meta_load_dict["meta"], "data") + + def test_save_and_load_gguf_metadata_arrays(self): + if not os.path.isdir(self.test_dir): + os.mkdir(self.test_dir) + + save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") + save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} + + # Test scalars and one dimensional arrays + for t in [ + mx.uint8, + mx.int8, + mx.uint16, + mx.int16, + mx.uint32, + mx.int32, + mx.uint64, + mx.int64, + mx.float32, + ]: + for shape in [(), (2,)]: + arr = mx.random.uniform(shape=shape).astype(t) + metadata = {"meta": arr} + mx.save_gguf(save_file_mlx, save_dict, metadata) + _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertEqual(len(meta_load_dict), 1) + self.assertTrue("meta" in meta_load_dict) + self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr)) + self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype) + + for t in [mx.float16, mx.bfloat16, mx.complex64]: + with self.assertRaises(ValueError): + arr = mx.array(1, t) + metadata = {"meta": arr} + mx.save_gguf(save_file_mlx, save_dict, metadata) + + def test_save_and_load_gguf_metadata_mixed(self): + if not os.path.isdir(self.test_dir): + 
os.mkdir(self.test_dir) + + save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf") + save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)} + + # Test string and array + arr = mx.array(1.5) + metadata = {"meta1": arr, "meta2": "data"} + mx.save_gguf(save_file_mlx, save_dict, metadata) + _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertEqual(len(meta_load_dict), 2) + self.assertTrue("meta1" in meta_load_dict) + self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr)) + self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype) + self.assertTrue("meta2" in meta_load_dict) + self.assertEqual(meta_load_dict["meta2"], "data") + + # Test list of strings + metadata = {"meta": ["data1", "data2", "data345"]} + mx.save_gguf(save_file_mlx, save_dict, metadata) + _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertEqual(len(meta_load_dict), 1) + self.assertEqual(meta_load_dict["meta"], metadata["meta"]) + + # Test a combination of stuff + metadata = { + "meta1": ["data1", "data2", "data345"], + "meta2": mx.array([1, 2, 3, 4]), + "meta3": "data", + "meta4": mx.array(1.5), + } + mx.save_gguf(save_file_mlx, save_dict, metadata) + _, meta_load_dict = mx.load(save_file_mlx, return_metadata=True) + self.assertEqual(len(meta_load_dict), 4) + for k, v in metadata.items(): + if isinstance(v, mx.array): + self.assertTrue(mx.array_equal(meta_load_dict[k], v)) + else: + self.assertEqual(meta_load_dict[k], v) + def test_save_and_load_fs(self): if not os.path.isdir(self.test_dir): os.mkdir(self.test_dir) diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 8b77a2eb3..51d1659f3 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -37,33 +37,161 @@ TEST_CASE("test save_safetensors") { TEST_CASE("test gguf") { std::string file_path = get_temp_file("test_arr.gguf"); using dict = std::unordered_map; - dict map = { + dict original_weights = { {"test", array({1.0f, 2.0f, 3.0f, 4.0f})}, {"test2", 
reshape(arange(6), {3, 2})}}; - save_gguf(file_path, map); - auto loaded = load_gguf(file_path); - CHECK_EQ(loaded.size(), 2); - CHECK_EQ(loaded.count("test"), 1); - CHECK_EQ(loaded.count("test2"), 1); - for (auto [k, v] : loaded) { - CHECK(array_equal(v, map.at(k)).item()); + { + // Check saving loading just arrays, no metadata + save_gguf(file_path, original_weights); + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 0); + CHECK_EQ(loaded_weights.size(), 2); + CHECK_EQ(loaded_weights.count("test"), 1); + CHECK_EQ(loaded_weights.count("test2"), 1); + for (auto [k, v] : loaded_weights) { + CHECK(array_equal(v, original_weights.at(k)).item()); + } + } + + // Test saving and loading string metadata + std::unordered_map original_metadata; + original_metadata.insert({"test_str", "my string"}); + + save_gguf(file_path, original_weights, original_metadata); + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 1); + CHECK_EQ(loaded_metadata.count("test_str"), 1); + CHECK_EQ(std::get(loaded_metadata.at("test_str")), "my string"); + + CHECK_EQ(loaded_weights.size(), 2); + CHECK_EQ(loaded_weights.count("test"), 1); + CHECK_EQ(loaded_weights.count("test2"), 1); + for (auto [k, v] : loaded_weights) { + CHECK(array_equal(v, original_weights.at(k)).item()); } std::vector unsupported_types = { bool_, uint8, uint32, uint64, int64, bfloat16, complex64}; for (auto t : unsupported_types) { dict to_save = {{"test", astype(arange(5), t)}}; - CHECK_THROWS(save_gguf(file_path, to_save)); + CHECK_THROWS(save_gguf(file_path, to_save, original_metadata)); } - std::vector supported_types = {int8, int32, float16}; + std::vector supported_types = {int8, int32, float16, float32}; for (auto t : supported_types) { auto arr = astype(arange(5), t); dict to_save = {{"test", arr}}; - save_gguf(file_path, to_save); - auto loaded = load_gguf(file_path); - CHECK(array_equal(loaded.at("test"), arr).item()); + 
save_gguf(file_path, to_save, original_metadata); + const auto& [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK(array_equal(loaded_weights.at("test"), arr).item()); + } +} + +TEST_CASE("test gguf metadata") { + std::string file_path = get_temp_file("test_arr.gguf"); + using dict = std::unordered_map; + dict original_weights = { + {"test", array({1.0f, 2.0f, 3.0f, 4.0f})}, + {"test2", reshape(arange(6), {3, 2})}}; + + // Scalar array + { + std::unordered_map original_metadata; + original_metadata.insert({"test_arr", array(1.0)}); + save_gguf(file_path, original_weights, original_metadata); + + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 1); + CHECK_EQ(loaded_metadata.count("test_arr"), 1); + + auto arr = std::get(loaded_metadata.at("test_arr")); + CHECK_EQ(arr.item(), 1.0f); + } + + // 1D Array + { + std::unordered_map original_metadata; + auto arr = array({1.0, 2.0}); + original_metadata.insert({"test_arr", arr}); + save_gguf(file_path, original_weights, original_metadata); + + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 1); + CHECK_EQ(loaded_metadata.count("test_arr"), 1); + + auto loaded_arr = std::get(loaded_metadata.at("test_arr")); + CHECK(array_equal(arr, loaded_arr).item()); + + // Preserves dims + arr = array({1.0}); + original_metadata["test_arr"] = arr; + save_gguf(file_path, original_weights, original_metadata); + + std::tie(loaded_weights, loaded_metadata) = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 1); + CHECK_EQ(loaded_metadata.count("test_arr"), 1); + + loaded_arr = std::get(loaded_metadata.at("test_arr")); + CHECK(array_equal(arr, loaded_arr).item()); + } + + // > 1D array throws + { + std::unordered_map original_metadata; + original_metadata.insert({"test_arr", array({1.0}, {1, 1})}); + CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); + } + + // empty array throws + { + 
std::unordered_map original_metadata; + original_metadata.insert({"test_arr", array({})}); + CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata)); + } + + // vector of string + { + std::unordered_map original_metadata; + std::vector data = {"data1", "data2", "data1234"}; + original_metadata.insert({"meta", data}); + save_gguf(file_path, original_weights, original_metadata); + + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 1); + CHECK_EQ(loaded_metadata.count("meta"), 1); + auto& strs = std::get>(loaded_metadata["meta"]); + CHECK_EQ(strs.size(), 3); + for (int i = 0; i < strs.size(); ++i) { + CHECK_EQ(strs[i], data[i]); + } + } + + // vector of string, string, scalar, and array + { + std::unordered_map original_metadata; + std::vector data = {"data1", "data2", "data1234"}; + original_metadata.insert({"meta1", data}); + original_metadata.insert({"meta2", array(2.5)}); + original_metadata.insert({"meta3", array({1, 2, 3})}); + original_metadata.insert({"meta4", "last"}); + save_gguf(file_path, original_weights, original_metadata); + + auto [loaded_weights, loaded_metadata] = load_gguf(file_path); + CHECK_EQ(loaded_metadata.size(), 4); + auto& strs = std::get>(loaded_metadata["meta1"]); + CHECK_EQ(strs.size(), 3); + for (int i = 0; i < strs.size(); ++i) { + CHECK_EQ(strs[i], data[i]); + } + auto& arr = std::get(loaded_metadata["meta2"]); + CHECK_EQ(arr.item(), 2.5); + + arr = std::get(loaded_metadata["meta3"]); + CHECK(array_equal(arr, array({1, 2, 3})).item()); + + auto& str = std::get(loaded_metadata["meta4"]); + CHECK_EQ(str, "last"); } }