mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 06:29:09 +08:00
added get_device to do reductions on the cpu if metal
This commit is contained in:
parent
4ee0d0bb55
commit
3bb6b1d44a
@ -2,9 +2,20 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
|
||||
inline mlx::core::Device get_device() {
|
||||
if (mlx::core::metal::is_available()) {
|
||||
return mlx::core::Device::cpu;
|
||||
} else if (mlx::core::cu::is_available()) {
|
||||
return mlx::core::Device::gpu;
|
||||
}
|
||||
throw std::runtime_error("No available device for distributed operations.");
|
||||
}
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
namespace {
|
||||
@ -24,6 +35,7 @@ array all_sum(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
@ -31,7 +43,7 @@ array all_sum(
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -40,6 +52,7 @@ array all_max(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
@ -47,8 +60,7 @@ array all_max(
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -57,6 +69,7 @@ array all_min(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
@ -64,8 +77,7 @@ array all_min(
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(
|
||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -74,6 +86,7 @@ array all_gather(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
@ -88,7 +101,7 @@ array all_gather(
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
|
||||
std::make_shared<AllGather>(to_stream(s, dev), group),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -98,6 +111,7 @@ array send(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot send to a singleton group");
|
||||
@ -113,7 +127,7 @@ array send(
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
|
||||
std::make_shared<Send>(to_stream(s, dev), group, dst),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -124,6 +138,7 @@ array recv(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||
@ -138,7 +153,7 @@ array recv(
|
||||
return array(
|
||||
std::move(shape),
|
||||
std::move(dtype),
|
||||
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
|
||||
std::make_shared<Recv>(to_stream(s, dev), group, src),
|
||||
std::vector<array>{});
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user