From 16ec0556a0427545955a6e6630b2d1ef57c39f3c Mon Sep 17 00:00:00 2001 From: xnorai <147757538+xnorai@users.noreply.github.com> Date: Mon, 18 Nov 2024 07:22:51 -0800 Subject: [PATCH] Allocate raw JSON metadata buffer on the heap, and limit its size (#1596) * Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB * Set the upper size limit for the header to 100 MB (100,000,000 bytes) as in Rust safetensors --- mlx/io/safetensors.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index f022fb25f..5c4854186 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. // #include <json.hpp> +#include <memory> #include "mlx/io.h" @@ -109,15 +110,17 @@ SafetensorsLoad load_safetensors( } uint64_t jsonHeaderLength = 0; + // This is the same limit as in the original Rust Safetensors code. + constexpr uint64_t kMaxJsonHeaderLength = 100000000; in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8); - if (jsonHeaderLength <= 0) { + if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) { throw std::runtime_error( "[load_safetensors] Invalid json header length " + in_stream->label()); } // Load the json metadata - char rawJson[jsonHeaderLength]; - in_stream->read(rawJson, jsonHeaderLength); - auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength); + auto rawJson = std::make_unique<char[]>(jsonHeaderLength); + in_stream->read(rawJson.get(), jsonHeaderLength); + auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength); // Should always be an object on the top-level if (!metadata.is_object()) { throw std::runtime_error(