Files
mlx/mlx/io/safetensors.cpp
Awni Hannun df58b4133a
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
[CUDA] Reduce use of managed memory (#2725)
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00

238 lines
6.6 KiB
C++

// Copyright © 2023 Apple Inc.
//
#include <json.hpp>
#include <memory>
#include <stack>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
using json = nlohmann::json;
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"
#define ST_F8_E4M3 "F8_E4M3"
// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
namespace mlx::core {
std::string dtype_to_safetensor_str(Dtype t) {
switch (t) {
case float32:
return ST_F32;
case bfloat16:
return ST_BF16;
case float16:
return ST_F16;
case int64:
return ST_I64;
case int32:
return ST_I32;
case int16:
return ST_I16;
case int8:
return ST_I8;
case uint64:
return ST_U64;
case uint32:
return ST_U32;
case uint16:
return ST_U16;
case uint8:
return ST_U8;
case bool_:
return ST_BOOL;
case complex64:
return ST_C64;
default:
throw std::runtime_error("[save_safetensors] received invalid dtype.");
}
}
Dtype dtype_from_safetensor_str(std::string_view str) {
if (str == ST_F32) {
return float32;
} else if (str == ST_F16) {
return float16;
} else if (str == ST_BF16) {
return bfloat16;
} else if (str == ST_I64) {
return int64;
} else if (str == ST_I32) {
return int32;
} else if (str == ST_I16) {
return int16;
} else if (str == ST_I8) {
return int8;
} else if (str == ST_U64) {
return uint64;
} else if (str == ST_U32) {
return uint32;
} else if (str == ST_U16) {
return uint16;
} else if (str == ST_U8) {
return uint8;
} else if (str == ST_BOOL) {
return bool_;
} else if (str == ST_C64) {
return complex64;
} else if (str == ST_F8_E4M3) {
// We convert this manually later
return uint8;
} else {
throw std::runtime_error(
"[safetensor] unsupported dtype " + std::string(str));
}
}
/** Load array from reader in safetensor format */
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
////////////////////////////////////////////////////////
// Open and check file
if (!in_stream->good() || !in_stream->is_open()) {
throw std::runtime_error(
"[load_safetensors] Failed to open " + in_stream->label());
}
auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code.
constexpr uint64_t kMaxJsonHeaderLength = 100000000;
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) {
throw std::runtime_error(
"[load_safetensors] Invalid json header length " + in_stream->label());
}
// Load the json metadata
auto rawJson = std::make_unique<char[]>(jsonHeaderLength);
in_stream->read(rawJson.get(), jsonHeaderLength);
auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength);
// Should always be an object on the top-level
if (!metadata.is_object()) {
throw std::runtime_error(
"[load_safetensors] Invalid json metadata " + in_stream->label());
}
size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata
std::unordered_map<std::string, array> res;
std::unordered_map<std::string, std::string> metadata_map;
for (const auto& item : metadata.items()) {
if (item.key() == "__metadata__") {
for (const auto& meta_item : item.value().items()) {
metadata_map.insert({meta_item.key(), meta_item.value()});
}
continue;
}
const std::string& dtype = item.value().at("dtype");
const Shape& shape = item.value().at("shape");
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
auto loaded_array = array(
shape,
type,
std::make_shared<Load>(
stream, in_stream, offset + data_offsets.at(0), false),
std::vector<array>{});
if (dtype == ST_F8_E4M3) {
loaded_array = from_fp8(loaded_array, bfloat16, s);
}
res.insert({item.key(), loaded_array});
}
return {res, metadata_map};
}
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
}
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error(
"[save_safetensors] Failed to open " + out_stream->label());
}
////////////////////////////////////////////////////////
// Check array map
json parent;
json _metadata;
for (auto& [key, value] : metadata) {
_metadata[key] = value;
}
parent["__metadata__"] = _metadata;
{
std::vector<array> to_eval;
to_eval.reserve(a.size());
for (auto& p : a) {
p.second = contiguous(p.second);
to_eval.push_back(p.second);
}
eval(std::move(to_eval));
}
size_t offset = 0;
for (auto& [key, arr] : a) {
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);
}
json child;
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
child["shape"] = arr.shape();
child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()};
parent[key] = child;
offset += arr.nbytes();
}
auto header = parent.dump();
uint64_t header_len = header.length();
out_stream->write(reinterpret_cast<char*>(&header_len), 8);
out_stream->write(header.c_str(), header_len);
for (auto& [key, arr] : a) {
out_stream->write(arr.data<char>(), arr.nbytes());
}
}
void save_safetensors(
std::string file,
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Add .safetensors to file name if it is not there
if (file.length() < 12 ||
file.substr(file.length() - 12, 12) != ".safetensors")
file += ".safetensors";
// Serialize array
save_safetensors(
std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
}
} // namespace mlx::core