diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 9c251a944..900ae2c81 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -2,9 +2,21 @@
 
 #include <sstream>
 
+#include "mlx/backend/cuda/cuda.h"
+#include "mlx/backend/metal/metal.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 
+// Default device for distributed primitives: the CPU when Metal is available,
+// the GPU when only CUDA is. When neither backend is available it falls back
+// to the CPU instead of throwing, so CPU-only builds are not rejected here.
+static mlx::core::Device get_device() {
+  if (!mlx::core::metal::is_available() && mlx::core::cu::is_available()) {
+    return mlx::core::Device::gpu;
+  }
+  return mlx::core::Device::cpu;
+}
+
 namespace mlx::core::distributed {
 
 namespace {
@@ -24,6 +36,7 @@ array all_sum(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
@@ -31,7 +44,7 @@ array all_sum(
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
+      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
       {x});
 }
 
@@ -40,6 +53,7 @@ array all_max(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
@@ -47,8 +61,7 @@ array all_max(
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(
-          to_stream(s, Device::cpu), group, AllReduce::Max),
+      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
       {x});
 }
 
@@ -57,6 +70,7 @@ array all_min(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
@@ -64,8 +78,7 @@ array all_min(
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(
-          to_stream(s, Device::cpu), group, AllReduce::Min),
+      std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
       {x});
 }
 
@@ -74,6 +87,7 @@ array all_gather(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     return x;
@@ -88,7 +102,7 @@ array all_gather(
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
+      std::make_shared<AllGather>(to_stream(s, dev), group),
       {x});
 }
 
@@ -98,6 +112,7 @@ array send(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot send to a singleton group");
@@ -113,7 +128,7 @@ array send(
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
+      std::make_shared<Send>(to_stream(s, dev), group, dst),
       {x});
 }
 
@@ -124,6 +139,7 @@ array recv(
     std::optional<Group> group_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
+  auto dev = get_device();
 
   if (group.size() == 1) {
     throw std::invalid_argument("Cannot recv from a singleton group");
@@ -138,7 +154,7 @@ array recv(
   return array(
       std::move(shape),
       std::move(dtype),
-      std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
+      std::make_shared<Recv>(to_stream(s, dev), group, src),
       std::vector<array>{});
 }
 