Fix ring of 2 and allow scalars in API (#1906)

This commit is contained in:
Angelos Katharopoulos
2025-02-25 17:03:01 -08:00
committed by GitHub
parent 7d042f17fe
commit 6bf00ef631
2 changed files with 52 additions and 11 deletions

View File

@@ -10,6 +10,8 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
@@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_sum",
&mx::distributed::all_sum,
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_sum(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_gather",
&mx::distributed::all_gather,
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_gather(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"send",
&mx::distributed::send,
[](const ScalarOrArray& x,
int dst,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::send(to_array(x), dst, group, s);
},
"x"_a,
"dst"_a,
nb::kw_only(),
@@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"recv_like",
&mx::distributed::recv_like,
[](const ScalarOrArray& x,
int src,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::recv_like(to_array(x), src, group, s);
},
"x"_a,
"src"_a,
nb::kw_only(),