MLX
Loading...
Searching...
No Matches
distributed.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <memory>
6
7#include "mlx/array.h"
8
10
11/* Check if a communication backend is available */
13
19struct Group {
20 Group(std::shared_ptr<void> group) : group_(group) {}
21
22 int rank();
23 int size();
24
33 Group split(int color, int key = -1);
34
35 const std::shared_ptr<void>& raw_group() {
36 return group_;
37 }
38
39 private:
40 std::shared_ptr<void> group_{nullptr};
41};
42
51Group init(bool strict = false);
52
53namespace detail {
54
55/* Return the communication stream. */
57
58/* Perform an all reduce sum operation */
59void all_sum(Group group, const array& input, array& output);
60
61/* Perform an all gather operation */
62void all_gather(Group group, const array& input, array& output);
63
64} // namespace detail
65
66} // namespace mlx::core::distributed
Definition array.h:20
void all_sum(Group group, const array &input, array &output)
void all_gather(Group group, const array &input, array &output)
Definition distributed.h:9
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
Definition stream.h:9
A distributed::Group represents a group of independent MLX processes that can communicate.
Definition distributed.h:19
const std::shared_ptr< void > & raw_group()
Definition distributed.h:35
Group(std::shared_ptr< void > group)
Definition distributed.h:20
Group split(int color, int key=-1)
Split the group according to the provided color.