fix malloc or wait deadlock (#1976)

This commit is contained in:
Awni Hannun
2025-03-20 16:48:43 -07:00
committed by GitHub
parent 1177d28395
commit 7b7e2352cd
55 changed files with 201 additions and 217 deletions

View File

@@ -248,9 +248,9 @@ void sdpa_vector_2pass(
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
d.add_temporary(intermediate, s.index);
d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index);
@@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
o.copy_shared_buffer(q);
} else {
if (o.shape(2) == 1) {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
o.set_data(allocator::malloc(o.nbytes()));
} else {
auto strides = o.strides();
strides[2] = o.shape(1) * o.shape(3);
@@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
auto flags = q.flags();
flags.row_contiguous = q.shape(1) == 1;
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
o.size(),
std::move(strides),
flags);
allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
}
}
@@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
allocator::malloc_or_wait(o.nbytes()),
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);