mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-29 03:22:54 +08:00
parent 6b0d30bb85
commit 1f6ab6a556

.gitignore (vendored, 4 lines changed)
@@ -6,6 +6,10 @@ __pycache__/
 # C extensions
 *.so
 
+# tensor files
+*.safe
+*.safetensors
+
 # Metal libraries
 *.metallib
 venv/
@@ -98,6 +98,15 @@ elseif (MLX_BUILD_METAL)
     ${QUARTZ_LIB})
 endif()
 
+MESSAGE(STATUS "Downloading json")
+FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PUBLIC
+  $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
+  $<INSTALL_INTERFACE:include/json>
+)
+
 find_library(ACCELERATE_LIBRARY Accelerate)
 if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
   message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
@@ -152,6 +161,8 @@ if (MLX_BUILD_BENCHMARKS)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
 endif()
 
+
+
 # ----------------------------- Installation -----------------------------
 include(GNUInstallDirs)
 
@@ -83,6 +83,7 @@ Operations
    save
    savez
    savez_compressed
+   save_safetensors
    sigmoid
    sign
    sin
@@ -8,7 +8,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
@@ -19,7 +18,7 @@ target_sources(
 )
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
 else()
@@ -5,7 +5,7 @@
 #include <utility>
 
 #include "mlx/allocator.h"
-#include "mlx/load.h"
+#include "mlx/io/load.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {

mlx/io/CMakeLists.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
+)
@@ -6,7 +6,7 @@
 #include <limits>
 #include <sstream>
 
-#include "mlx/load.h"
+#include "mlx/io/load.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

mlx/io/safetensor.cpp (new file, 189 lines)
@@ -0,0 +1,189 @@
+#include "mlx/io/safetensor.h"
+
+#include <stack>
+
+namespace mlx::core {
+
+std::string dtype_to_safetensor_str(Dtype t) {
+  switch (t) {
+    case float32:
+      return ST_F32;
+    case bfloat16:
+      return ST_BF16;
+    case float16:
+      return ST_F16;
+    case int64:
+      return ST_I64;
+    case int32:
+      return ST_I32;
+    case int16:
+      return ST_I16;
+    case int8:
+      return ST_I8;
+    case uint64:
+      return ST_U64;
+    case uint32:
+      return ST_U32;
+    case uint16:
+      return ST_U16;
+    case uint8:
+      return ST_U8;
+    case bool_:
+      return ST_BOOL;
+    case complex64:
+      return ST_C64;
+  }
+}
+
+Dtype dtype_from_safetensor_str(std::string str) {
+  if (str == ST_F32) {
+    return float32;
+  } else if (str == ST_F16) {
+    return float16;
+  } else if (str == ST_BF16) {
+    return bfloat16;
+  } else if (str == ST_I64) {
+    return int64;
+  } else if (str == ST_I32) {
+    return int32;
+  } else if (str == ST_I16) {
+    return int16;
+  } else if (str == ST_I8) {
+    return int8;
+  } else if (str == ST_U64) {
+    return uint64;
+  } else if (str == ST_U32) {
+    return uint32;
+  } else if (str == ST_U16) {
+    return uint16;
+  } else if (str == ST_U8) {
+    return uint8;
+  } else if (str == ST_BOOL) {
+    return bool_;
+  } else if (str == ST_C64) {
+    return complex64;
+  } else {
+    throw std::runtime_error("[safetensor] unsupported dtype " + str);
+  }
+}
+
+/** Load array map from reader in safetensor format */
+std::unordered_map<std::string, array> load_safetensors(
+    std::shared_ptr<io::Reader> in_stream,
+    StreamOrDevice s) {
+  ////////////////////////////////////////////////////////
+  // Open and check file
+  if (!in_stream->good() || !in_stream->is_open()) {
+    throw std::runtime_error(
+        "[load_safetensors] Failed to open " + in_stream->label());
+  }
+
+  uint64_t jsonHeaderLength = 0;
+  in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
+  if (jsonHeaderLength <= 0) {
+    throw std::runtime_error(
+        "[load_safetensors] Invalid json header length " + in_stream->label());
+  }
+  // Load the json metadata
+  char rawJson[jsonHeaderLength];
+  in_stream->read(rawJson, jsonHeaderLength);
+  auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength);
+  // Should always be an object on the top-level
+  if (!metadata.is_object()) {
+    throw std::runtime_error(
+        "[load_safetensors] Invalid json metadata " + in_stream->label());
+  }
+  size_t offset = jsonHeaderLength + 8;
+  // Load the arrays using metadata
+  std::unordered_map<std::string, array> res;
+  for (const auto& item : metadata.items()) {
+    if (item.key() == "__metadata__") {
+      // ignore metadata for now
+      continue;
+    }
+    std::string dtype = item.value().at("dtype");
+    std::vector<int> shape = item.value().at("shape");
+    std::vector<size_t> data_offsets = item.value().at("data_offsets");
+    Dtype type = dtype_from_safetensor_str(dtype);
+    auto loaded_array = array(
+        shape,
+        type,
+        std::make_unique<Load>(
+            to_stream(s), in_stream, offset + data_offsets.at(0), false),
+        std::vector<array>{});
+    res.insert({item.key(), loaded_array});
+  }
+  return res;
+}
+
+std::unordered_map<std::string, array> load_safetensors(
+    const std::string& file,
+    StreamOrDevice s) {
+  return load_safetensors(std::make_shared<io::FileReader>(file), s);
+}
+
+/** Save array map to out stream in safetensors format */
+void save_safetensors(
+    std::shared_ptr<io::Writer> out_stream,
+    std::unordered_map<std::string, array> a,
+    std::optional<bool> retain_graph_) {
+  ////////////////////////////////////////////////////////
+  // Check file
+  if (!out_stream->good() || !out_stream->is_open()) {
+    throw std::runtime_error(
+        "[save_safetensors] Failed to open " + out_stream->label());
+  }
+
+  ////////////////////////////////////////////////////////
+  // Check array map
+  json parent;
+  parent["__metadata__"] = json::object({
+      {"format", "mlx"},
+  });
+  size_t offset = 0;
+  for (auto& [key, arr] : a) {
+    arr.eval(retain_graph_.value_or(arr.is_tracer()));
+    if (arr.nbytes() == 0) {
+      throw std::invalid_argument(
+          "[save_safetensors] cannot serialize an empty array key: " + key);
+    }
+
+    if (!arr.flags().contiguous) {
+      throw std::invalid_argument(
+          "[save_safetensors] cannot serialize a non-contiguous array key: " +
+          key);
+    }
+    json child;
+    child["dtype"] = dtype_to_safetensor_str(arr.dtype());
+    child["shape"] = arr.shape();
+    child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
+    parent[key] = child;
+    offset += arr.nbytes();
+  }
+
+  auto header = parent.dump();
+  uint64_t header_len = header.length();
+  out_stream->write(reinterpret_cast<char*>(&header_len), 8);
+  out_stream->write(header.c_str(), header_len);
+  for (auto& [key, arr] : a) {
+    out_stream->write(arr.data<char>(), arr.nbytes());
+  }
+}
+
+void save_safetensors(
+    const std::string& file_,
+    std::unordered_map<std::string, array> a,
+    std::optional<bool> retain_graph) {
+  // Open and check file
+  std::string file = file_;
+
+  // Add .safetensors to file name if it is not there
+  if (file.length() < 12 ||
+      file.substr(file.length() - 12, 12) != ".safetensors")
+    file += ".safetensors";
+
+  // Serialize array
+  save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
+}
+
+} // namespace mlx::core
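
The reader above expects the standard safetensors container layout: an 8-byte little-endian header length, then a JSON header whose per-tensor data_offsets are relative to the end of that header, then the raw tensor bytes. A minimal Python sketch of the same parsing, using only the standard library (the file name is a placeholder):

    import json
    import struct

    with open("weights.safetensors", "rb") as f:  # placeholder path
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian length
        header = json.loads(f.read(header_len))

    for name, entry in header.items():
        if name == "__metadata__":  # optional free-form metadata block
            continue
        start, end = entry["data_offsets"]  # byte range relative to the end of the header
        print(name, entry["dtype"], entry["shape"], end - start, "bytes")
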

mlx/io/safetensor.h (new file, 32 lines)
@@ -0,0 +1,32 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <json.hpp>
+
+#include "mlx/io/load.h"
+#include "mlx/ops.h"
+#include "mlx/primitives.h"
+
+using json = nlohmann::json;
+
+namespace mlx::core {
+
+#define ST_F16 "F16"
+#define ST_BF16 "BF16"
+#define ST_F32 "F32"
+
+#define ST_BOOL "BOOL"
+#define ST_I8 "I8"
+#define ST_I16 "I16"
+#define ST_I32 "I32"
+#define ST_I64 "I64"
+#define ST_U8 "U8"
+#define ST_U16 "U16"
+#define ST_U32 "U32"
+#define ST_U64 "U64"
+
+// Note: Complex numbers aren't in the spec yet so this could change -
+// https://github.com/huggingface/safetensors/issues/389
+#define ST_C64 "C64"
+} // namespace mlx::core
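
These ST_* strings are the dtype values written into the JSON header. For a single float32 tensor of shape (2, 2), the header produced by save_safetensors would look roughly like the Python dict below (the tensor name is illustrative):

    header = {
        "__metadata__": {"format": "mlx"},
        "weight": {                    # illustrative tensor name
            "dtype": "F32",            # one of the ST_* strings above
            "shape": [2, 2],
            "data_offsets": [0, 16],   # 4 floats * 4 bytes, relative to the end of the header
        },
    }
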

mlx/ops.h (18 lines changed)
@@ -7,7 +7,7 @@
 
 #include "array.h"
 #include "device.h"
-#include "load.h"
+#include "io/load.h"
 #include "stream.h"
 
 namespace mlx::core {
@@ -1057,4 +1057,20 @@ array dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+/** Load array map from .safetensors file format */
+std::unordered_map<std::string, array> load_safetensors(
+    std::shared_ptr<io::Reader> in_stream,
+    StreamOrDevice s = {});
+std::unordered_map<std::string, array> load_safetensors(
+    const std::string& file,
+    StreamOrDevice s = {});
+
+void save_safetensors(
+    std::shared_ptr<io::Writer> in_stream,
+    std::unordered_map<std::string, array>,
+    std::optional<bool> retain_graph = std::nullopt);
+void save_safetensors(
+    const std::string& file,
+    std::unordered_map<std::string, array>,
+    std::optional<bool> retain_graph = std::nullopt);
 } // namespace mlx::core
@@ -4,7 +4,7 @@
 
 #include "array.h"
 #include "device.h"
-#include "load.h"
+#include "io/load.h"
 #include "stream.h"
 
 #define DEFINE_GRADS() \
@@ -6,12 +6,11 @@
 #include <cstring>
 #include <fstream>
 #include <stdexcept>
-#include <string>
 #include <string_view>
 #include <unordered_map>
 #include <vector>
 
-#include "mlx/load.h"
+#include "mlx/io/load.h"
 #include "mlx/ops.h"
 #include "mlx/utils.h"
 #include "python/src/load.h"
@@ -161,40 +160,68 @@ class PyFileReader : public io::Reader {
   py::object tell_func_;
 };
 
-DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
-  py::module_ zipfile = py::module_::import("zipfile");
-
-  // Assume .npz file if it is zipped
-  if (is_zip_file(zipfile, file)) {
-    // Output dictionary filename in zip -> loaded array
-    std::unordered_map<std::string, array> array_dict;
-
-    // Create python ZipFile object
-    ZipFileWrapper zipfile_object(zipfile, file);
-    for (const std::string& st : zipfile_object.namelist()) {
-      // Open zip file as a python file stream
-      py::object sub_file = zipfile_object.open(st);
-
-      // Create array from python fille stream
-      auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
-
-      // Remove .npy from file if it is there
-      auto key = st;
-      if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
-        key = st.substr(0, st.length() - 4);
-
-      // Add array to dict
-      array_dict.insert({key, arr});
-    }
-
+std::unordered_map<std::string, array> mlx_load_safetensor_helper(
+    py::object file,
+    StreamOrDevice s) {
+  if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
+    return {load_safetensors(py::cast<std::string>(file), s)};
+  } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
-    for (auto& [key, arr] : array_dict) {
+    auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
+    {
       py::gil_scoped_release gil;
-      arr.eval();
+      for (auto& [key, arr] : arr) {
+        arr.eval();
+      }
     }
+    return {arr};
+  }
 
-    return {array_dict};
-  } else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
+  throw std::invalid_argument(
+      "[load_safetensors] Input must be a file-like object, or string");
+}
+
+std::unordered_map<std::string, array> mlx_load_npz_helper(
+    py::object file,
+    StreamOrDevice s) {
+  py::module_ zipfile = py::module_::import("zipfile");
+  if (!is_zip_file(zipfile, file)) {
+    throw std::invalid_argument(
+        "[load_npz] Input must be a zip file or a file-like object that can be "
+        "opened with zipfile.ZipFile");
+  }
+  // Output dictionary filename in zip -> loaded array
+  std::unordered_map<std::string, array> array_dict;
+
+  // Create python ZipFile object
+  ZipFileWrapper zipfile_object(zipfile, file);
+  for (const std::string& st : zipfile_object.namelist()) {
+    // Open zip file as a python file stream
+    py::object sub_file = zipfile_object.open(st);
+
+    // Create array from python fille stream
+    auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
+
+    // Remove .npy from file if it is there
+    auto key = st;
+    if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
+      key = st.substr(0, st.length() - 4);
+
+    // Add array to dict
+    array_dict.insert({key, arr});
+  }
+
+  // If we don't own the stream and it was passed to us, eval immediately
+  for (auto& [key, arr] : array_dict) {
+    py::gil_scoped_release gil;
+    arr.eval();
+  }
+
+  return {array_dict};
+}
+
+array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
+  if (py::isinstance<py::str>(file)) { // Assume .npy file path string
     return {load(py::cast<std::string>(file), s)};
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
@@ -205,9 +232,41 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
     }
     return {arr};
   }
 
   throw std::invalid_argument(
-      "[load] Input must be a file-like object, string, or pathlib.Path");
+      "[load_npy] Input must be a file-like object, or string");
+}
+
+DictOrArray mlx_load_helper(
+    py::object file,
+    std::optional<std::string> format,
+    StreamOrDevice s) {
+  if (!format.has_value()) {
+    std::string fname;
+    if (py::isinstance<py::str>(file)) {
+      fname = py::cast<std::string>(file);
+    } else if (is_istream_object(file)) {
+      fname = file.attr("name").cast<std::string>();
+    } else {
+      throw std::invalid_argument(
+          "[load] Input must be a file-like object, or string");
+    }
+    size_t ext = fname.find_last_of('.');
+    if (ext == std::string::npos) {
+      throw std::invalid_argument(
+          "[load] Could not infer file format from extension");
+    }
+    format.emplace(fname.substr(ext + 1));
+  }
+
+  if (format.value() == "safetensors") {
+    return mlx_load_safetensor_helper(file, s);
+  } else if (format.value() == "npz") {
+    return mlx_load_npz_helper(file, s);
+  } else if (format.value() == "npy") {
+    return mlx_load_npy_helper(file, s);
+  } else {
+    throw std::invalid_argument("[load] Unknown file format " + format.value());
+  }
 }
 
 ///////////////////////////////////////////////////////////////////////////////
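
With this dispatcher, the Python mx.load call either infers the format from the file extension (or from the .name attribute of an open file object) or takes it explicitly. A small usage sketch, with placeholder file names:

    import mlx.core as mx

    # Extension-based inference: ".safetensors" routes to the safetensors loader.
    weights = mx.load("weights.safetensors")

    # Explicit format, e.g. when working with an already-open file object.
    with open("weights.safetensors", "rb") as f:
        weights = mx.load(f, format="safetensors")
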
@@ -305,7 +364,7 @@ void mlx_save_helper(
   }
 
   throw std::invalid_argument(
-      "[save] Input must be a file-like object, string, or pathlib.Path");
+      "[save] Input must be a file-like object, or string");
 }
 
 void mlx_savez_helper(
@@ -361,3 +420,25 @@ void mlx_savez_helper(
 
   return;
 }
+
+void mlx_save_safetensor_helper(
+    py::object file,
+    py::dict d,
+    std::optional<bool> retain_graph) {
+  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
+  if (py::isinstance<py::str>(file)) {
+    save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
+    return;
+  } else if (is_ostream_object(file)) {
+    auto writer = std::make_shared<PyFileWriter>(file);
+    {
+      py::gil_scoped_release gil;
+      save_safetensors(writer, arrays_map, retain_graph);
+    }
+
+    return;
+  }
+
+  throw std::invalid_argument(
+      "[save_safetensors] Input must be a file-like object, or string");
+}
@@ -3,6 +3,8 @@
 #pragma once
 
 #include <pybind11/pybind11.h>
+#include <optional>
+#include <string>
 #include <unordered_map>
 #include <variant>
 #include "mlx/ops.h"
@@ -12,7 +14,18 @@ using namespace mlx::core;
 
 using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
 
-DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
+std::unordered_map<std::string, array> mlx_load_safetensor_helper(
+    py::object file,
+    StreamOrDevice s);
+void mlx_save_safetensor_helper(
+    py::object file,
+    py::dict d,
+    std::optional<bool> retain_graph = std::nullopt);
+
+DictOrArray mlx_load_helper(
+    py::object file,
+    std::optional<std::string> format,
+    StreamOrDevice s);
 void mlx_save_helper(
     py::object file,
     array a,
@@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
         Args:
             file (str): File to which the array is saved
             arr (array): Array to be saved.
-            retain_graph (bool, optional): Optional argument to retain graph
-              during array evaluation before saving. If not provided the graph
-              is retained if we are during a function transformation. Default:
-              None
-
+            retain_graph (bool, optional): Whether or not to retain the graph
+              during array evaluation. If left unspecified the graph is retained
+              only if saving is done in a function transformation. Default: ``None``
       )pbdoc");
   m.def(
       "savez",
@@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) {
       &mlx_load_helper,
       "file"_a,
       py::pos_only(),
+      "format"_a = none,
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
+        load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 
-        Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
+        Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
 
         Args:
-            file (file, str): File in which the array is saved
+            file (file, str): File in which the array is saved.
+            format (str, optional): Format of the file. If ``None``, the format
+              is inferred from the file extension. Supported formats: ``npy``,
+              ``npz``, and ``safetensors``. Default: ``None``.
         Returns:
-            result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
+            result (array, dict):
+              A single array if loading from a ``.npy`` file or a dict mapping
+              names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
+      )pbdoc");
+  m.def(
+      "save_safetensors",
+      &mlx_save_safetensor_helper,
+      "file"_a,
+      "arrays"_a,
+      py::pos_only(),
+      "retain_graph"_a = std::nullopt,
+      py::kw_only(),
+      R"pbdoc(
+        save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
+
+        Save array(s) to a binary file in ``.safetensors`` format.
+
+        For more information on the format see https://huggingface.co/docs/safetensors/index.
+
+        Args:
+            file (file, str): File in which the array is saved.
+            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
+            retain_graph (bool, optional): Whether or not to retain the graph
+              during array evaluation. If left unspecified the graph is retained
+              only if saving is done in a function transformation. Default: ``None``.
       )pbdoc");
   m.def(
       "where",
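
Taken together with the bindings above, a round trip through the new format looks roughly like this (array names and the output path are illustrative):

    import mlx.core as mx

    arrays = {"weight": mx.ones((4, 4)), "bias": mx.zeros((4,))}

    # ".safetensors" is appended automatically when the path lacks it.
    mx.save_safetensors("checkpoint", arrays)

    loaded = mx.load("checkpoint.safetensors")
    assert mx.array_equal(loaded["weight"], arrays["weight"]).item()
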
@@ -64,6 +64,33 @@ class TestLoad(mlx_tests.MLXTestCase):
         load_arr_mlx_npy = np.load(save_file_mlx)
         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
 
+    def test_save_and_load_safetensors(self):
+        if not os.path.isdir(self.test_dir):
+            os.mkdir(self.test_dir)
+
+        for dt in self.dtypes + ["bfloat16"]:
+            with self.subTest(dtype=dt):
+                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
+                    with self.subTest(shape=shape):
+                        save_file_mlx = os.path.join(
+                            self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
+                        )
+                        save_dict = {
+                            "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
+                            if dt in ["float32", "float16", "bfloat16"]
+                            else mx.ones(shape, dtype=getattr(mx, dt))
+                        }
+
+                        with open(save_file_mlx, "wb") as f:
+                            mx.save_safetensors(f, save_dict)
+                        with open(save_file_mlx, "rb") as f:
+                            load_dict = mx.load(f)
+
+                        self.assertTrue("test" in load_dict)
+                        self.assertTrue(
+                            mx.array_equal(load_dict["test"], save_dict["test"])
+                        )
+
     def test_save_and_load_fs(self):
         if not os.path.isdir(self.test_dir):
@@ -14,6 +14,26 @@ std::string get_temp_file(const std::string& name) {
   return std::filesystem::temp_directory_path().append(name);
 }
 
+TEST_CASE("test save_safetensors") {
+  std::string file_path = get_temp_file("test_arr.safetensors");
+  auto map = std::unordered_map<std::string, array>();
+  map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
+  map.insert({"test2", ones({2, 2})});
+  save_safetensors(file_path, map);
+  auto safeDict = load_safetensors(file_path);
+  CHECK_EQ(safeDict.size(), 2);
+  CHECK_EQ(safeDict.count("test"), 1);
+  CHECK_EQ(safeDict.count("test2"), 1);
+  array test = safeDict.at("test");
+  CHECK_EQ(test.dtype(), float32);
+  CHECK_EQ(test.shape(), std::vector<int>({4}));
+  CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item<bool>());
+  array test2 = safeDict.at("test2");
+  CHECK_EQ(test2.dtype(), float32);
+  CHECK_EQ(test2.shape(), std::vector<int>({2, 2}));
+  CHECK(array_equal(test2, ones({2, 2})).item<bool>());
+}
+
 TEST_CASE("test single array serialization") {
   // Basic test
   {