mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix normalization check_input (#1452)
This commit is contained in:
committed by
GitHub
parent
5900e3249f
commit
d878015228
@@ -227,9 +227,12 @@ void steel_matmul_conv_groups(
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
@@ -379,8 +382,12 @@ void steel_matmul(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -507,9 +514,12 @@ void steel_matmul(
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -680,8 +690,12 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@@ -886,8 +900,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1000,8 +1018,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1136,9 +1158,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -1433,8 +1458,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1545,9 +1574,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -1773,8 +1805,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1914,9 +1950,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
||||
copies.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user