mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu * Another copy scalar changed to fill --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -526,7 +526,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
copy_gpu(zero, out, CopyType::Scalar, s);
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
return;
|
||||
@@ -1156,7 +1156,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
copy_gpu(zero, out, CopyType::Scalar, s);
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
return;
|
||||
@@ -1565,7 +1565,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
copy_gpu(zero, out, CopyType::Scalar, s);
|
||||
fill_gpu(zero, out, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user