mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix ring of 2 and allow scalars in API (#1906)
This commit is contained in:

committed by
GitHub

parent
7d042f17fe
commit
6bf00ef631
@@ -10,6 +10,8 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_sum",
|
||||
&mx::distributed::all_sum,
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_sum(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_gather",
|
||||
&mx::distributed::all_gather,
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_gather(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"send",
|
||||
&mx::distributed::send,
|
||||
[](const ScalarOrArray& x,
|
||||
int dst,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::send(to_array(x), dst, group, s);
|
||||
},
|
||||
"x"_a,
|
||||
"dst"_a,
|
||||
nb::kw_only(),
|
||||
@@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"recv_like",
|
||||
&mx::distributed::recv_like,
|
||||
[](const ScalarOrArray& x,
|
||||
int src,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::recv_like(to_array(x), src, group, s);
|
||||
},
|
||||
"x"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
|
Reference in New Issue
Block a user