mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)
* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB * Set the upper size limit for the header to 100K as in Rust safetensors
This commit is contained in:
parent
610af352d4
commit
16ec0556a0
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
//
|
//
|
||||||
#include <json.hpp>
|
#include <json.hpp>
|
||||||
|
#include <memory>
|
||||||
#include <stack>
|
#include <stack>
|
||||||
|
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
@ -109,15 +110,17 @@ SafetensorsLoad load_safetensors(
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint64_t jsonHeaderLength = 0;
|
uint64_t jsonHeaderLength = 0;
|
||||||
|
// This is the same limit as in the original Rust Safetensors code.
|
||||||
|
constexpr uint64_t kMaxJsonHeaderLength = 100000000;
|
||||||
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
||||||
if (jsonHeaderLength <= 0) {
|
if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensors] Invalid json header length " + in_stream->label());
|
"[load_safetensors] Invalid json header length " + in_stream->label());
|
||||||
}
|
}
|
||||||
// Load the json metadata
|
// Load the json metadata
|
||||||
char rawJson[jsonHeaderLength];
|
auto rawJson = std::make_unique<char[]>(jsonHeaderLength);
|
||||||
in_stream->read(rawJson, jsonHeaderLength);
|
in_stream->read(rawJson.get(), jsonHeaderLength);
|
||||||
auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength);
|
auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength);
|
||||||
// Should always be an object on the top-level
|
// Should always be an object on the top-level
|
||||||
if (!metadata.is_object()) {
|
if (!metadata.is_object()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
Loading…
Reference in New Issue
Block a user