mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove Hazard tracking with Fences (#1509)
* remove hazard tracking * with fence map * no hazard tracking with fences * nits * fix fence retain * cleanup * fix quantized rebase
This commit is contained in:
@@ -226,13 +226,8 @@ void steel_matmul_regular(
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
// Record copies
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
@@ -382,12 +377,7 @@ void steel_matmul(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
d.add_temporary(std::move(zero), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user