MLX
 
Loading...
Searching...
No Matches
primitives.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
7#include "mlx/primitives.h"
8
10
// Common base of the distributed primitives in this header: derives from
// Primitive and stores the Group handed back by group().
// NOTE(review): listing lines 13-14 are absent from this extract. The symbol
// index at the end of the page records `DistPrimitive(Stream stream, Group group)`
// as defined at primitives.h:13, so the constructor presumably sits there —
// confirm against the real header.
11class DistPrimitive : public Primitive {
12 public:
15
// Read-only accessor for the stored group.
16 const Group& group() const {
17 return group_;
18 }
19
20 private:
// Held by value; only exposed through group().
21 Group group_;
22};
23
24class AllReduce : public DistPrimitive {
25 public:
26 enum ReduceType { And, Or, Sum, Prod, Min, Max };
27
29 : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
30
31 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
32 override;
33 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
34 override;
35 std::pair<std::vector<array>, std::vector<int>> vmap(
36 const std::vector<array>& inputs,
37 const std::vector<int>& axes) override;
38 std::vector<array> jvp(
39 const std::vector<array>& primals,
40 const std::vector<array>& tangents,
41 const std::vector<int>& argnums) override;
42 std::vector<array> vjp(
43 const std::vector<array>& primals,
44 const std::vector<array>& cotangents,
45 const std::vector<int>& argnums,
46 const std::vector<array>& outputs) override;
47
48 void print(std::ostream& os) override {
49 switch (reduce_type_) {
50 case And:
51 os << "And";
52 case Or:
53 os << "And";
54 break;
55 case Sum:
56 os << "Sum";
57 break;
58 case Prod:
59 os << "Prod";
60 break;
61 case Min:
62 os << "Min";
63 break;
64 case Max:
65 os << "Max";
66 break;
67 }
68 os << " AllReduce";
69 }
70
71 private:
72 ReduceType reduce_type_;
73};
74
// All-gather primitive built on DistPrimitive. Every member visible here is a
// declaration only; the definitions are not part of this header listing.
// NOTE(review): listing lines 77 and 97 are absent from this extract. The page
// index records `AllGather(Stream stream, Group group)` at primitives.h:77;
// the content of line 97 cannot be determined from here — check the real header.
75class AllGather : public DistPrimitive {
76 public:
78
// Evaluate on the CPU / GPU for the given inputs and populate the outputs.
79 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
80 override;
81 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
82 override;
83
// Vectorize the primitive across the given axes.
84 std::pair<std::vector<array>, std::vector<int>> vmap(
85 const std::vector<array>& inputs,
86 const std::vector<int>& axes) override;
// The Jacobian-vector product.
87 std::vector<array> jvp(
88 const std::vector<array>& primals,
89 const std::vector<array>& tangents,
90 const std::vector<int>& argnums) override;
// The vector-Jacobian product.
91 std::vector<array> vjp(
92 const std::vector<array>& primals,
93 const std::vector<array>& cotangents,
94 const std::vector<int>& argnums,
95 const std::vector<array>& outputs) override;
96
98};
99
// Send primitive built on DistPrimitive; it additionally stores an integer
// `dst_` (presumably the destination process within the group — confirm).
// NOTE(review): listing lines 102 and 113 are absent from this extract. The
// page index records `Send(Stream stream, Group group, int dst)` at
// primitives.h:102, which is the header of the constructor whose initializer
// list follows; the content of line 113 cannot be determined from here.
100class Send : public DistPrimitive {
101 public:
// Initializer list of the constructor: forwards stream/group to the base and
// stores dst.
103 : DistPrimitive(stream, group), dst_(dst) {}
104
// Evaluate on the CPU / GPU for the given inputs and populate the outputs.
105 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
106 override;
107 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
108 override;
// Vectorize the primitive across the given axes.
109 std::pair<std::vector<array>, std::vector<int>> vmap(
110 const std::vector<array>& inputs,
111 const std::vector<int>& axes) override;
112
114
115 private:
// Set once from the constructor's `dst` argument.
116 int dst_;
117};
118
// Receive primitive built on DistPrimitive; it additionally stores an integer
// `src_` (presumably the source process within the group — confirm).
// NOTE(review): listing lines 121 and 129 are absent from this extract. The
// page index records `Recv(Stream stream, Group group, int src)` at
// primitives.h:121, which is the header of the constructor whose initializer
// list follows; the content of line 129 cannot be determined from here.
119class Recv : public DistPrimitive {
120 public:
// Initializer list of the constructor: forwards stream/group to the base and
// stores src.
122 : DistPrimitive(stream, group), src_(src) {}
123
// Evaluate on the CPU / GPU for the given inputs and populate the outputs.
124 void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
125 override;
126 void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
127 override;
128
130
131 private:
// Set once from the constructor's `src` argument.
132 int src_;
133};
134
135} // namespace mlx::core::distributed
const Stream & stream()
The stream the primitive will run on.
Definition primitives.h:58
Primitive(Stream stream)
Definition primitives.h:50
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
AllGather(Stream stream, Group group)
Definition primitives.h:77
AllReduce(Stream stream, Group group, ReduceType reduce_type)
Definition primitives.h:28
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:48
ReduceType
Definition primitives.h:26
@ Sum
Definition primitives.h:26
@ Min
Definition primitives.h:26
@ Or
Definition primitives.h:26
@ And
Definition primitives.h:26
@ Max
Definition primitives.h:26
@ Prod
Definition primitives.h:26
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
const Group & group() const
Definition primitives.h:16
DistPrimitive(Stream stream, Group group)
Definition primitives.h:13
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Recv(Stream stream, Group group, int src)
Definition primitives.h:121
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Send(Stream stream, Group group, int dst)
Definition primitives.h:102
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the out...
Definition distributed.h:9
Definition stream.h:9
A distributed::Group represents a group of independent mlx processes that can communicate.
Definition distributed.h:24