don't duplicate malloc with custom kernel init (#1830)

This commit is contained in:
Awni Hannun 2025-02-04 13:20:57 -08:00 committed by GitHub
parent f6c0499b8d
commit a229c8cef0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -15,10 +15,11 @@ void CustomKernel::eval_gpu(
std::vector<array> copies;
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}