mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix malloc or wait deadlock (#1976)
This commit is contained in:
@@ -248,9 +248,9 @@ void sdpa_vector_2pass(
|
||||
intermediate_shape.pop_back();
|
||||
array sums(intermediate_shape, float32, nullptr, {});
|
||||
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||
- intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
- sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
- maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+ intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+ sums.set_data(allocator::malloc(sums.nbytes()));
+ maxs.set_data(allocator::malloc(maxs.nbytes()));
|
||||
d.add_temporary(intermediate, s.index);
|
||||
d.add_temporary(sums, s.index);
|
||||
d.add_temporary(maxs, s.index);
|
||||
@@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
if (o.shape(2) == 1) {
|
||||
- o.set_data(allocator::malloc_or_wait(o.nbytes()));
+ o.set_data(allocator::malloc(o.nbytes()));
|
||||
} else {
|
||||
auto strides = o.strides();
|
||||
strides[2] = o.shape(1) * o.shape(3);
|
||||
@@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
auto flags = q.flags();
|
||||
flags.row_contiguous = q.shape(1) == 1;
|
||||
o.set_data(
|
||||
- allocator::malloc_or_wait(o.nbytes()),
- o.size(),
- std::move(strides),
- flags);
+ allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
- allocator::malloc_or_wait(o.nbytes()),
+ allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
|
||||
Reference in New Issue
Block a user