Fix normalization check_input (#1452)

This commit is contained in:
Angelos Katharopoulos
2024-10-03 13:26:56 -07:00
committed by GitHub
parent 5900e3249f
commit d878015228
11 changed files with 147 additions and 51 deletions

View File

@@ -227,9 +227,12 @@ void steel_matmul_conv_groups(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void steel_matmul(
@@ -379,8 +382,12 @@ void steel_matmul(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
@@ -507,9 +514,12 @@ void steel_matmul(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -680,8 +690,12 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
/////////////////////////////////////////////////////////////////////////////
@@ -886,8 +900,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
@@ -1000,8 +1018,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
@@ -1136,9 +1158,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1433,8 +1458,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
@@ -1545,9 +1574,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1773,8 +1805,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return;
}
@@ -1914,9 +1950,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
} // namespace mlx::core