Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)

* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB

* Set the upper size limit for the header to 100K as in Rust safetensors
This commit is contained in:
xnorai 2024-11-18 07:22:51 -08:00 committed by GitHub
parent 610af352d4
commit 16ec0556a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
// //
#include <json.hpp> #include <json.hpp>
#include <memory>
#include <stack> #include <stack>
#include "mlx/io.h" #include "mlx/io.h"
@ -109,15 +110,17 @@ SafetensorsLoad load_safetensors(
} }
uint64_t jsonHeaderLength = 0; uint64_t jsonHeaderLength = 0;
// This is the same limit as in the original Rust Safetensors code.
constexpr uint64_t kMaxJsonHeaderLength = 100000000;
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8); in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
if (jsonHeaderLength <= 0) { if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) {
throw std::runtime_error( throw std::runtime_error(
"[load_safetensors] Invalid json header length " + in_stream->label()); "[load_safetensors] Invalid json header length " + in_stream->label());
} }
// Load the json metadata // Load the json metadata
char rawJson[jsonHeaderLength]; auto rawJson = std::make_unique<char[]>(jsonHeaderLength);
in_stream->read(rawJson, jsonHeaderLength); in_stream->read(rawJson.get(), jsonHeaderLength);
auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength); auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength);
// Should always be an object on the top-level // Should always be an object on the top-level
if (!metadata.is_object()) { if (!metadata.is_object()) {
throw std::runtime_error( throw std::runtime_error(