mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 06:29:09 +08:00
added get_device to do reductions on the cpu if metal
This commit is contained in:
parent
4ee0d0bb55
commit
3bb6b1d44a
@ -2,9 +2,20 @@
|
|||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
|
|
||||||
|
inline mlx::core::Device get_device() {
|
||||||
|
if (mlx::core::metal::is_available()) {
|
||||||
|
return mlx::core::Device::cpu;
|
||||||
|
} else if (mlx::core::cu::is_available()) {
|
||||||
|
return mlx::core::Device::gpu;
|
||||||
|
}
|
||||||
|
throw std::runtime_error("No available device for distributed operations.");
|
||||||
|
}
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -24,6 +35,7 @@ array all_sum(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@ -31,7 +43,7 @@ array all_sum(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,6 +52,7 @@ array all_max(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@ -47,8 +60,7 @@ array all_max(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Max),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +69,7 @@ array all_min(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@ -64,8 +77,7 @@ array all_min(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Min),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,6 +86,7 @@ array all_gather(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
@ -88,7 +101,7 @@ array all_gather(
|
|||||||
return array(
|
return array(
|
||||||
std::move(result_shape),
|
std::move(result_shape),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
|
std::make_shared<AllGather>(to_stream(s, dev), group),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,6 +111,7 @@ array send(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot send to a singleton group");
|
throw std::invalid_argument("Cannot send to a singleton group");
|
||||||
@ -113,7 +127,7 @@ array send(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
|
std::make_shared<Send>(to_stream(s, dev), group, dst),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,6 +138,7 @@ array recv(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
|
auto dev = get_device();
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||||
@ -138,7 +153,7 @@ array recv(
|
|||||||
return array(
|
return array(
|
||||||
std::move(shape),
|
std::move(shape),
|
||||||
std::move(dtype),
|
std::move(dtype),
|
||||||
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
|
std::make_shared<Recv>(to_stream(s, dev), group, src),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user