mpi send use input as output (#1750)

* mpi send use input as output

* move earlier
This commit is contained in:
Awni Hannun
2025-01-06 06:08:43 -08:00
committed by GitHub
parent eab93985b8
commit 058d6ce683
5 changed files with 19 additions and 3 deletions

View File

@@ -78,7 +78,10 @@ array send(
}
return array(
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s), group, dst),
{x});
}
array recv(

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
@@ -105,6 +106,7 @@ void Send::eval_cpu(
assert(outputs.size() == 1);
distributed::detail::send(group(), inputs[0], dst_);
move_or_copy(inputs[0], outputs[0]);
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(