Remove Hazard tracking with Fences (#1509)

* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
Awni Hannun, 2024-10-21 19:33:32 -07:00 (committed by GitHub)
parent d15fa13daf
commit c26208f67d
25 changed files with 268 additions and 299 deletions
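
Every hunk in the diff below makes the same substitution: a kernel-launch path that previously installed its own addCompletedHandler on the stream's command buffer, solely to keep the temporary `copies` arrays (or the `zero` fill array) alive until the GPU finished, now just records those arrays with the device via `d.add_temporaries(std::move(copies), s.index)` or `d.add_temporary(std::move(zero), s.index)`. A small, self-contained sketch of the bookkeeping this implies appears after the diff.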

mlx/backend/metal/matmul.cpp

@@ -226,13 +226,8 @@ void steel_matmul_regular(
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  // Record copies
+  d.add_temporaries(std::move(copies), s.index);
 }

 void steel_matmul(
@@ -382,12 +377,7 @@ void steel_matmul(
       compute_encoder.dispatchThreads(grid_dims, group_dims);
     }

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }
@@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (a_pre.size() == 0 || b_pre.size() == 0) {
     array zero = array(0, a_pre.dtype());
     fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
     return;
   }
@@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }

   /////////////////////////////////////////////////////////////////////////////
@@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }
@@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder.dispatchThreads(grid_dims, group_dims);
     }

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }
@@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
 }

 void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (a_pre.size() == 0 || b_pre.size() == 0) {
     array zero = array(0, a_pre.dtype());
     fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
     return;
   }
@@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }
@@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
 }

 void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (a_pre.size() == 0 || b_pre.size() == 0) {
     array zero = array(0, a_pre.dtype());
     fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
     return;
   }
@@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
     return;
   }
@@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
 }

 } // namespace mlx::core
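
For illustration, here is a minimal, self-contained model of the per-stream temporary bookkeeping that the call sites above appear to rely on. It is a sketch under assumptions, not MLX's actual Device: the type ArrayHandle, the class DeviceModel, and make_completed_handler are invented names, and the real device works with MTL::CommandBuffer objects rather than plain callbacks.

// Hypothetical sketch only; none of these names exist in MLX.
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

using ArrayHandle = std::shared_ptr<void>; // stand-in for mlx::core::array

class DeviceModel {
 public:
  // Record one temporary against the command buffer for stream `index`
  // (mirrors the role of d.add_temporary(std::move(zero), s.index)).
  void add_temporary(ArrayHandle arr, int index) {
    temporaries_[index].push_back(std::move(arr));
  }

  // Record a batch of temporaries, e.g. the `copies` made before a matmul
  // (mirrors d.add_temporaries(std::move(copies), s.index)).
  void add_temporaries(std::vector<ArrayHandle> arrs, int index) {
    auto& held = temporaries_[index];
    std::move(arrs.begin(), arrs.end(), std::back_inserter(held));
  }

  // When the command buffer for `index` is committed, one completion
  // callback takes ownership of everything recorded so far and releases it
  // after the GPU work finishes, instead of each call site installing its
  // own addCompletedHandler.
  std::function<void()> make_completed_handler(int index) {
    return [held = std::move(temporaries_[index])]() mutable { held.clear(); };
  }

 private:
  std::unordered_map<int, std::vector<ArrayHandle>> temporaries_;
};

Whatever the real implementation looks like, the visible effect in this file is the design choice the sketch mirrors: many per-call-site completion handlers collapse into single add_temporary/add_temporaries calls, and lifetime management is deferred to one place per command buffer.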