Added get_device() to run distributed reductions on the CPU when Metal is the available backend

This commit is contained in:
Anastasiia Filippova 2025-08-20 18:00:16 +02:00
parent 4ee0d0bb55
commit 3bb6b1d44a

View File

@ -2,9 +2,20 @@
#include <sstream>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
// Select the device that distributed collectives should run on.
//
// On Metal builds the reductions are deliberately routed to the CPU
// (there is no Metal distributed backend), while CUDA builds run them
// on the GPU. If neither backend is available there is nowhere to run
// a distributed op, so we fail loudly.
//
// Returns: the Device to pass to to_stream() for distributed primitives.
// Throws:  std::runtime_error when no backend is available.
inline mlx::core::Device get_device() {
  using namespace mlx::core;
  if (metal::is_available()) {
    // NOTE: intentionally CPU, not GPU — reductions run on the CPU under Metal.
    return Device::cpu;
  }
  if (cu::is_available()) {
    return Device::gpu;
  }
  throw std::runtime_error("No available device for distributed operations.");
}
namespace mlx::core::distributed {
namespace {
@ -24,6 +35,7 @@ array all_sum(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
@ -31,7 +43,7 @@ array all_sum(
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
{x});
}
@ -40,6 +52,7 @@ array all_max(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
@ -47,8 +60,7 @@ array all_max(
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Max),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
{x});
}
@ -57,6 +69,7 @@ array all_min(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
@ -64,8 +77,7 @@ array all_min(
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Min),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
{x});
}
@ -74,6 +86,7 @@ array all_gather(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
@ -88,7 +101,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
std::make_shared<AllGather>(to_stream(s, dev), group),
{x});
}
@ -98,6 +111,7 @@ array send(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
@ -113,7 +127,7 @@ array send(
return array(
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
std::make_shared<Send>(to_stream(s, dev), group, dst),
{x});
}
@ -124,6 +138,7 @@ array recv(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
@ -138,7 +153,7 @@ array recv(
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::make_shared<Recv>(to_stream(s, dev), group, src),
std::vector<array>{});
}