From ccaaa7d6df1742111e6784adf7c53220c33706a9 Mon Sep 17 00:00:00 2001 From: Melissa Kilby Date: Fri, 12 Dec 2025 02:11:18 -0800 Subject: [PATCH] fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877) --- mlx/backend/cpu/primitives.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index d5b917b84..4e59b1ebe 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { num_keys, kshape = keys.shape(), kstrides = keys.strides()]() mutable { + auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) { + if (4 * loc + 4 <= bytes_per_key) { + reinterpret_cast(cptr)[loc] = v; + } else { + std::copy( + reinterpret_cast(&v), + reinterpret_cast(&v) + bytes_per_key - 4 * loc, + cptr + 4 * loc); + } + }; + size_t out_skip = (bytes_per_key + 4 - 1) / 4; auto half_size = out_skip / 2; bool even = out_skip % 2 == 0; @@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { if (count.first < half_size) { auto rb = random::threefry2x32_hash(key, count); ptr[count.first++] = rb.first; - if (bytes_per_key % 4 > 0) { - std::copy( - reinterpret_cast(&rb.second), - reinterpret_cast(&rb.second) + bytes_per_key % 4, - cptr + 4 * count.second); - } else { - ptr[count.second] = rb.second; - } + copy_remaining(cptr, count.second, rb.second); } if (!even) { count.second = 0; - ptr[half_size] = random::threefry2x32_hash(key, count).first; + copy_remaining( + cptr, half_size, random::threefry2x32_hash(key, count).first); } } });