MLX
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6#include "mlx/primitives.h"
7
9
10class DistPrimitive : public Primitive {
11 public:
13 : Primitive(detail::communication_stream()), group_(group) {}
14
15 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
16 override {
17 throw std::runtime_error(
18 "Communication primitives cannot be run on the GPU");
19 }
20
21 const Group& group() const {
22 return group_;
23 }
24
25 private:
26 Group group_;
27};
28
29class AllReduce : public DistPrimitive {
30 public:
31 enum ReduceType { And, Or, Sum, Prod, Min, Max };
32
34 : DistPrimitive(group), reduce_type_(reduce_type) {}
35
36 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
37 override;
38 std::pair<std::vector<array>, std::vector<int>> vmap(
39 const std::vector<array>& inputs,
40 const std::vector<int>& axes) override;
41 std::vector<array> jvp(
42 const std::vector<array>& primals,
43 const std::vector<array>& tangents,
44 const std::vector<int>& argnums) override;
45 std::vector<array> vjp(
46 const std::vector<array>& primals,
47 const std::vector<array>& cotangents,
48 const std::vector<int>& argnums,
49 const std::vector<array>& outputs) override;
50
51 void print(std::ostream& os) override {
52 switch (reduce_type_) {
53 case And:
54 os << "And";
55 case Or:
56 os << "And";
57 break;
58 case Sum:
59 os << "Sum";
60 break;
61 case Prod:
62 os << "Prod";
63 break;
64 case Min:
65 os << "Min";
66 break;
67 case Max:
68 os << "Max";
69 break;
70 }
71 os << " AllReduce";
72 }
73
74 private:
75 ReduceType reduce_type_;
76};
77
78class AllGather : public DistPrimitive {
79 public:
81
82 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
83 override;
84 std::pair<std::vector<array>, std::vector<int>> vmap(
85 const std::vector<array>& inputs,
86 const std::vector<int>& axes) override;
87 std::vector<array> jvp(
88 const std::vector<array>& primals,
89 const std::vector<array>& tangents,
90 const std::vector<int>& argnums) override;
91 std::vector<array> vjp(
92 const std::vector<array>& primals,
93 const std::vector<array>& cotangents,
94 const std::vector<int>& argnums,
95 const std::vector<array>& outputs) override;
96
98};
99
100} // namespace mlx::core::distributed
Definition primitives.h:48
Definition primitives.h:78
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
AllGather(Group group)
Definition primitives.h:80
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
Definition primitives.h:29
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:51
AllReduce(Group group, ReduceType reduce_type)
Definition primitives.h:33
ReduceType
Definition primitives.h:31
@ Sum
Definition primitives.h:31
@ Min
Definition primitives.h:31
@ Or
Definition primitives.h:31
@ And
Definition primitives.h:31
@ Max
Definition primitives.h:31
@ Prod
Definition primitives.h:31
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the outputs.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition primitives.h:10
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:15
const Group & group() const
Definition primitives.h:21
DistPrimitive(Group group)
Definition primitives.h:12
Definition distributed.h:9
Definition ops.h:159
A distributed::Group represents a group of independent mlx processes that can communicate.
Definition distributed.h:19