fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled

This commit is contained in:
Melissa Kilby
2025-12-12 02:11:18 -08:00
committed by GitHub
parent f3e5ca5414
commit ccaaa7d6df

View File

@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys, num_keys,
kshape = keys.shape(), kshape = keys.shape(),
kstrides = keys.strides()]() mutable { kstrides = keys.strides()]() mutable {
auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
if (4 * loc + 4 <= bytes_per_key) {
reinterpret_cast<uint32_t*>(cptr)[loc] = v;
} else {
std::copy(
reinterpret_cast<char*>(&v),
reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
cptr + 4 * loc);
}
};
size_t out_skip = (bytes_per_key + 4 - 1) / 4; size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2; auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0; bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
if (count.first < half_size) { if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count); auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first; ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) { copy_remaining(cptr, count.second, rb.second);
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
} }
if (!even) { if (!even) {
count.second = 0; count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first; copy_remaining(
cptr, half_size, random::threefry2x32_hash(key, count).first);
} }
} }
}); });