mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
don't duplicate malloc with custom kernel init (#1830)
This commit is contained in:
parent
f6c0499b8d
commit
a229c8cef0
@@ -15,10 +15,11 @@ void CustomKernel::eval_gpu(
|
|||||||
std::vector<array> copies;
|
std::vector<array> copies;
|
||||||
|
|
||||||
for (auto& out : outputs) {
|
for (auto& out : outputs) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
if (init_value_) {
|
if (init_value_) {
|
||||||
copies.emplace_back(init_value_.value(), out.dtype());
|
copies.emplace_back(init_value_.value(), out.dtype());
|
||||||
fill_gpu(copies.back(), out, s);
|
fill_gpu(copies.back(), out, s);
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user