Safetensor support (#215)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-27 05:06:55 -05:00
committed by GitHub
parent 6b0d30bb85
commit 1f6ab6a556
17 changed files with 476 additions and 52 deletions

6
mlx/io/CMakeLists.txt Normal file
View File

@@ -0,0 +1,6 @@
# Add the I/O implementation files (load.cpp and safetensor.cpp from this
# directory) to the mlx target as private sources.
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
)

242
mlx/io/load.cpp Normal file
View File

@@ -0,0 +1,242 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cstring>
#include <fstream>
#include <limits>
#include <sstream>
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
// Adapted from
// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp
namespace mlx::core {
namespace {
// Six-byte signature compared against the start of a .npy stream on load and
// written at the start on save.
static constexpr uint8_t MAGIC[] = {
    0x93, // leading marker byte
    0x4e, 0x55, 0x4d, 0x50, 0x59, // ASCII "NUMPY"
};
/**
 * Report whether the host stores multi-byte integers most-significant byte
 * first.
 *
 * The lowest-addressed byte of a known 32-bit pattern is copied out with
 * std::memcpy. A union is deliberately not used: reading a union member other
 * than the one last written is undefined behaviour in C++.
 *
 * @return true on a big-endian host, false otherwise.
 */
inline bool is_big_endian_() {
  const int32_t probe = 0x01234567;
  uint8_t first_byte = 0;
  std::memcpy(&first_byte, &probe, sizeof(first_byte));
  return first_byte == 0x01;
}
} // namespace
/**
 * Save array to out stream in .npy format.
 *
 * The array is evaluated, a header dictionary (descr, fortran_order, shape) is
 * built, and then magic + version + header length + header + raw array bytes
 * are written to the stream. Version 1.0 (16-bit header length) is used when
 * the header is short enough, otherwise version 2.0 (32-bit header length).
 * The header length is always emitted least-significant byte first.
 *
 * @param out_stream   Writer that receives the bytes.
 * @param a            Array to serialize; must be non-empty and contiguous.
 * @param retain_graph Forwarded to array::eval.
 * @throws std::invalid_argument if the array is empty or non-contiguous.
 * @throws std::runtime_error if the writer is not good or not open.
 */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
  ////////////////////////////////////////////////////////
  // Check array
  a.eval(retain_graph);
  if (a.nbytes() == 0) {
    throw std::invalid_argument("[save] cannot serialize an empty array");
  }
  if (!a.flags().contiguous) {
    throw std::invalid_argument(
        "[save] cannot serialize a non-contiguous array");
  }

  ////////////////////////////////////////////////////////
  // Check file
  if (!out_stream->good() || !out_stream->is_open()) {
    throw std::runtime_error("[save] Failed to open " + out_stream->label());
  }

  ////////////////////////////////////////////////////////
  // Prepare header
  std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
  std::ostringstream header;
  header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
         << " 'fortran_order': " << fortran_order << ","
         << " 'shape': (";
  for (auto i : a.shape()) {
    header << i << ", ";
  }
  header << ")}";

  size_t header_len = static_cast<size_t>(header.tellp());
  bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();

  // Version 1.0 stores the header length in 2 bytes, version 2.0 in 4 bytes.
  const size_t len_field_size = is_v1 ? 2 : 4;

  // Pad out magic + version + header_len + header + \n to be divisible by 16:
  // add the distance from the unpadded size up to the next multiple of 16.
  const size_t unpadded = 6 + 2 + len_field_size + header_len + 1;
  const size_t padding = (16 - unpadded % 16) % 16;
  header << std::string(padding, ' ') << '\n';
  const std::string header_str = header.str();

  // Magic, version (major, minor) and the little-endian header length. The
  // length is emitted byte by byte so the result does not depend on the
  // byte order of the host.
  std::string preamble(reinterpret_cast<const char*>(MAGIC), 6);
  preamble.push_back(static_cast<char>(is_v1 ? 0x01 : 0x02));
  preamble.push_back(static_cast<char>(0x00));
  const uint64_t stored_len = header_str.length();
  for (size_t i = 0; i < len_field_size; ++i) {
    preamble.push_back(static_cast<char>((stored_len >> (8 * i)) & 0xff));
  }

  ////////////////////////////////////////////////////////
  // Serialize array
  out_stream->write(preamble.data(), preamble.length());
  out_stream->write(header_str.data(), header_str.length());
  out_stream->write(a.data<char>(), a.nbytes());
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) {
// Open and check file
std::string file = file_;
// Add .npy to file name if it is not there
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
}
/**
 * Load array from reader in .npy format.
 *
 * Only the preamble and header dictionary are read here. The array data is
 * not read by this function: the returned array is built on a Load primitive
 * that is given the reader and the byte offset at which the data starts.
 *
 * The dtype and fortran_order fields are taken from fixed positions in the
 * header, so the header must follow the layout
 * "{'descr': 'XXX', 'fortran_order': False, 'shape': (...)}" with a
 * three-character descr.
 *
 * @param in_stream Reader positioned at the start of the .npy data.
 * @param s         Stream or device for the resulting array.
 * @return The array described by the header.
 * @throws std::runtime_error if the reader is unusable, the magic or version
 *         is wrong, or the header is too short or has no shape tuple.
 */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
  ////////////////////////////////////////////////////////
  // Open and check file
  if (!in_stream->good() || !in_stream->is_open()) {
    throw std::runtime_error("[load] Failed to open " + in_stream->label());
  }

  ////////////////////////////////////////////////////////
  // Read header and prepare array details

  // Read and check magic. The buffer is zero-initialized so that a short read
  // is compared against known bytes rather than indeterminate ones.
  char read_magic_and_ver[8] = {};
  in_stream->read(read_magic_and_ver, 8);
  if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {
    throw std::runtime_error("[load] Invalid header in " + in_stream->label());
  }

  // Read and check version
  if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
    throw std::runtime_error(
        "[load] Unsupport npy format version in " + in_stream->label());
  }

  // Read header len: 2 bytes for version 1, 4 bytes for version 2. The value
  // is assembled from individual bytes, least-significant first, so decoding
  // does not depend on the byte order of the host.
  int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4;
  unsigned char len_bytes[4] = {};
  in_stream->read(reinterpret_cast<char*>(len_bytes), header_len_size);
  size_t header_len = 0;
  for (int i = 0; i < header_len_size; ++i) {
    header_len |= static_cast<size_t>(len_bytes[i]) << (8 * i);
  }

  // Read the header
  std::string header(header_len, '\0');
  in_stream->read(&header[0], header_len);

  // The fixed positions used below (11..13 and 34) must exist.
  if (header.length() < 35) {
    throw std::runtime_error("[load] Invalid header in " + in_stream->label());
  }

  // Read data type from header
  std::string dtype_str = header.substr(11, 3);
  bool read_is_big_endian = dtype_str[0] == '>';
  Dtype dtype = dtype_from_array_protocol(dtype_str);

  // Read contiguity order ('T' is the first letter of "True")
  bool col_contiguous = header[34] == 'T';

  // Read array shape from header: the text between the last '(' and last ')'
  std::vector<int> shape;
  const size_t open_paren = header.find_last_of('(');
  const size_t close_paren = header.find_last_of(')');
  if (open_paren == std::string::npos || close_paren == std::string::npos ||
      close_paren < open_paren) {
    throw std::runtime_error(
        "[load] Unknown error while parsing header in " + in_stream->label());
  }
  std::string shape_str =
      header.substr(open_paren + 1, close_paren - open_paren - 1);
  while (!shape_str.empty()) {
    // Read current number and get position of comma
    size_t pos;
    int dim = std::stoi(shape_str, &pos);
    shape.push_back(dim);

    // Skip the comma and space and read the next number
    if (pos + 2 <= shape_str.length()) {
      shape_str = shape_str.substr(pos + 2);
    } else {
      shape_str = shape_str.substr(pos);
      if (!shape_str.empty() && shape_str != " " && shape_str != ",") {
        throw std::runtime_error(
            "[load] Unknown error while parsing header in " +
            in_stream->label());
      }
      shape_str = "";
    }
  }

  ////////////////////////////////////////////////////////
  // Build primitive

  // Data starts right after magic + version, the length field and the header
  size_t offset = 8 + header_len_size + header_len;
  bool swap_endianness = read_is_big_endian != is_big_endian_();

  // A column-contiguous array is loaded with its shape reversed and then
  // transposed.
  if (col_contiguous) {
    std::reverse(shape.begin(), shape.end());
  }
  auto loaded_array = array(
      shape,
      dtype,
      std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
      std::vector<array>{});
  if (col_contiguous) {
    loaded_array = transpose(loaded_array, s);
  }

  return loaded_array;
}
/** Open `file` with an io::FileReader and load it as a .npy array */
array load(const std::string& file, StreamOrDevice s) {
  auto reader = std::make_shared<io::FileReader>(file);
  return load(reader, s);
}
} // namespace mlx::core

114
mlx/io/load.h Normal file
View File

@@ -0,0 +1,114 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <fstream>
#include <istream>
#include <memory>
namespace mlx::core {
namespace io {
/**
 * Abstract interface for a seekable byte source.
 *
 * The destructor is virtual so that an implementation can be destroyed safely
 * through a pointer to Reader.
 */
class Reader {
 public:
  virtual ~Reader() = default;

  /** Whether the underlying source is open. */
  virtual bool is_open() const = 0;
  /** Whether the underlying source is in a usable state. */
  virtual bool good() const = 0;
  /** Current read position. */
  virtual size_t tell() const = 0;
  /** Move the read position by `off` relative to `way`. */
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  /** Read `n` bytes into `data`. */
  virtual void read(char* data, size_t n) = 0;
  /** Human-readable name of the source, used in error messages. */
  virtual std::string label() const = 0;
};
/**
 * Abstract interface for a seekable byte sink.
 *
 * The destructor is virtual so that an implementation can be destroyed safely
 * through a pointer to Writer.
 */
class Writer {
 public:
  virtual ~Writer() = default;

  /** Whether the underlying sink is open. */
  virtual bool is_open() const = 0;
  /** Whether the underlying sink is in a usable state. */
  virtual bool good() const = 0;
  /** Current write position. */
  virtual size_t tell() const = 0;
  /** Move the write position by `off` relative to `way`. */
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  /** Write `n` bytes starting at `data`. */
  virtual void write(const char* data, size_t n) = 0;
  /** Human-readable name of the sink, used in error messages. */
  virtual std::string label() const = 0;
};
/** Reader implementation that forwards every call to a std::ifstream. */
class FileReader : public Reader {
 public:
  /** Wrap an existing input stream; its label text is "stream". */
  explicit FileReader(const std::shared_ptr<std::ifstream>& is)
      : source_(is), name_("stream") {}

  /** Open `file_path` in binary mode; the path is used as the label text. */
  explicit FileReader(const std::string& file_path)
      : source_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
        name_(file_path) {}

  bool is_open() const override {
    return source_->is_open();
  }

  bool good() const override {
    return source_->good();
  }

  size_t tell() const override {
    return source_->tellg();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    source_->seekg(off, way);
  }

  void read(char* data, size_t n) override {
    source_->read(data, n);
  }

  /** Returns the label text prefixed with "file ". */
  std::string label() const override {
    std::string text = "file ";
    text += name_;
    return text;
  }

 private:
  std::shared_ptr<std::ifstream> source_;
  std::string name_;
};
/** Writer implementation that forwards every call to a std::ofstream. */
class FileWriter : public Writer {
 public:
  /** Wrap an existing output stream; its label text is "stream". */
  explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
      : sink_(is), name_("stream") {}

  /** Open `file_path` in binary mode; the path is used as the label text. */
  explicit FileWriter(const std::string& file_path)
      : sink_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
        name_(file_path) {}

  bool is_open() const override {
    return sink_->is_open();
  }

  bool good() const override {
    return sink_->good();
  }

  size_t tell() const override {
    return sink_->tellp();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    sink_->seekp(off, way);
  }

  void write(const char* data, size_t n) override {
    sink_->write(data, n);
  }

  /** Returns the label text prefixed with "file ". */
  std::string label() const override {
    std::string text = "file ";
    text += name_;
    return text;
  }

 private:
  std::shared_ptr<std::ofstream> sink_;
  std::string name_;
};
} // namespace io
} // namespace mlx::core

189
mlx/io/safetensor.cpp Normal file
View File

@@ -0,0 +1,189 @@
#include "mlx/io/safetensor.h"
#include <stack>
namespace mlx::core {
/**
 * Map a Dtype to the dtype name used in a safetensors header.
 *
 * @param t Dtype to translate.
 * @return The matching ST_* name.
 * @throws std::runtime_error if `t` matches none of the cases below.
 */
std::string dtype_to_safetensor_str(Dtype t) {
  switch (t) {
    case float32:
      return ST_F32;
    case bfloat16:
      return ST_BF16;
    case float16:
      return ST_F16;
    case int64:
      return ST_I64;
    case int32:
      return ST_I32;
    case int16:
      return ST_I16;
    case int8:
      return ST_I8;
    case uint64:
      return ST_U64;
    case uint32:
      return ST_U32;
    case uint16:
      return ST_U16;
    case uint8:
      return ST_U8;
    case bool_:
      return ST_BOOL;
    case complex64:
      return ST_C64;
  }
  // Reaching this point means no case matched. Flowing off the end of a
  // function that returns a value is undefined behaviour, so fail loudly.
  throw std::runtime_error("[safetensor] unsupported dtype");
}
/**
 * Map a safetensors dtype name to the corresponding Dtype.
 *
 * Throws std::runtime_error when the name is not one of the ST_* names.
 */
Dtype dtype_from_safetensor_str(std::string str) {
  struct NamedDtype {
    const char* name;
    Dtype dtype;
  };
  const NamedDtype known[] = {
      {ST_F32, float32},
      {ST_F16, float16},
      {ST_BF16, bfloat16},
      {ST_I64, int64},
      {ST_I32, int32},
      {ST_I16, int16},
      {ST_I8, int8},
      {ST_U64, uint64},
      {ST_U32, uint32},
      {ST_U16, uint16},
      {ST_U8, uint8},
      {ST_BOOL, bool_},
      {ST_C64, complex64},
  };
  for (const auto& candidate : known) {
    if (str == candidate.name) {
      return candidate.dtype;
    }
  }
  throw std::runtime_error("[safetensor] unsupported dtype " + str);
}
/**
 * Load arrays from reader in safetensor format.
 *
 * Reads the 8-byte header length (copied directly into a uint64_t, i.e. in
 * host byte order), then the JSON header, and creates one array per entry
 * other than "__metadata__". Tensor data is not read by this function: each
 * array is built on a Load primitive that is given the reader and the byte
 * offset of its data.
 *
 * @param in_stream Reader positioned at the start of the safetensors data.
 * @param s         Stream or device for the resulting arrays.
 * @return Map from tensor name to array.
 * @throws std::runtime_error if the reader is unusable, the header length is
 *         zero, the top-level JSON value is not an object, or a dtype name is
 *         not supported.
 */
std::unordered_map<std::string, array> load_safetensors(
    std::shared_ptr<io::Reader> in_stream,
    StreamOrDevice s) {
  ////////////////////////////////////////////////////////
  // Open and check file
  if (!in_stream->good() || !in_stream->is_open()) {
    throw std::runtime_error(
        "[load_safetensors] Failed to open " + in_stream->label());
  }

  uint64_t jsonHeaderLength = 0;
  in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
  if (jsonHeaderLength == 0) {
    throw std::runtime_error(
        "[load_safetensors] Invalid json header length " + in_stream->label());
  }

  // Load the json metadata. The buffer lives on the heap: its size comes from
  // the file, and a variable-length stack array is both non-standard C++ and
  // able to overflow the stack for a large header.
  std::string rawJson(jsonHeaderLength, '\0');
  in_stream->read(&rawJson[0], jsonHeaderLength);
  auto metadata = json::parse(rawJson.data(), rawJson.data() + rawJson.size());

  // Should always be an object on the top-level
  if (!metadata.is_object()) {
    throw std::runtime_error(
        "[load_safetensors] Invalid json metadata " + in_stream->label());
  }

  // Tensor data starts right after the length field and the JSON header
  size_t offset = jsonHeaderLength + 8;

  // Load the arrays using metadata
  std::unordered_map<std::string, array> res;
  for (const auto& item : metadata.items()) {
    if (item.key() == "__metadata__") {
      // ignore metadata for now
      continue;
    }
    std::string dtype = item.value().at("dtype");
    std::vector<int> shape = item.value().at("shape");
    std::vector<size_t> data_offsets = item.value().at("data_offsets");
    Dtype type = dtype_from_safetensor_str(dtype);
    auto loaded_array = array(
        shape,
        type,
        std::make_unique<Load>(
            to_stream(s), in_stream, offset + data_offsets.at(0), false),
        std::vector<array>{});
    res.insert({item.key(), loaded_array});
  }
  return res;
}
/** Open `file` with an io::FileReader and load its safetensors contents */
std::unordered_map<std::string, array> load_safetensors(
    const std::string& file,
    StreamOrDevice s) {
  auto reader = std::make_shared<io::FileReader>(file);
  return load_safetensors(reader, s);
}
/**
 * Save a map of named arrays to out stream in safetensors format.
 *
 * Output is the 8-byte header length, the JSON header (a "__metadata__" entry
 * plus dtype, shape and data_offsets for each array), and then the raw bytes
 * of every array in the map's iteration order.
 */
void save_safetensors(
    std::shared_ptr<io::Writer> out_stream,
    std::unordered_map<std::string, array> a,
    std::optional<bool> retain_graph_) {
  ////////////////////////////////////////////////////////
  // Check file
  if (!(out_stream->good() && out_stream->is_open())) {
    throw std::runtime_error(
        "[save_safetensors] Failed to open " + out_stream->label());
  }

  ////////////////////////////////////////////////////////
  // Check array map and describe each entry in the header
  json header_json;
  header_json["__metadata__"] = json::object({
      {"format", "mlx"},
  });

  size_t cursor = 0;
  for (auto& entry : a) {
    const std::string& key = entry.first;
    array& arr = entry.second;

    arr.eval(retain_graph_.value_or(arr.is_tracer()));
    if (arr.nbytes() == 0) {
      throw std::invalid_argument(
          "[save_safetensors] cannot serialize an empty array key: " + key);
    }
    if (!arr.flags().contiguous) {
      throw std::invalid_argument(
          "[save_safetensors] cannot serialize a non-contiguous array key: " +
          key);
    }

    // Byte range of this array within the data section
    const size_t begin = cursor;
    const size_t end = begin + arr.nbytes();

    json description;
    description["dtype"] = dtype_to_safetensor_str(arr.dtype());
    description["shape"] = arr.shape();
    description["data_offsets"] = std::vector<size_t>{begin, end};
    header_json[key] = description;

    cursor = end;
  }

  ////////////////////////////////////////////////////////
  // Serialize header, then array data
  const std::string serialized = header_json.dump();
  uint64_t header_len = serialized.length();
  out_stream->write(reinterpret_cast<char*>(&header_len), 8);
  out_stream->write(serialized.c_str(), header_len);
  for (auto& entry : a) {
    out_stream->write(entry.second.data<char>(), entry.second.nbytes());
  }
}
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a,
std::optional<bool> retain_graph) {
// Open and check file
std::string file = file_;
// Add .safetensors to file name if it is not there
if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
}
} // namespace mlx::core

32
mlx/io/safetensor.h Normal file
View File

@@ -0,0 +1,32 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <json.hpp>
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
using json = nlohmann::json;
namespace mlx::core {
// Dtype names as they appear in the "dtype" field of a safetensors header.
// Each macro expands to a string literal.
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core