mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
mpi send use input as output (#1750)
* mpi send use input as output * move earlier
This commit is contained in:
@@ -78,7 +78,10 @@ array send(
|
||||
}
|
||||
|
||||
return array(
|
||||
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<Send>(to_stream(s), group, dst),
|
||||
{x});
|
||||
}
|
||||
|
||||
array recv(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -105,6 +106,7 @@ void Send::eval_cpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
distributed::detail::send(group(), inputs[0], dst_);
|
||||
move_or_copy(inputs[0], outputs[0]);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||
|
||||
Reference in New Issue
Block a user