15 void eval_gpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
17 throw std::runtime_error(
18 "Communication primitives cannot be run on the GPU");
36 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
38 std::pair<std::vector<array>, std::vector<int>>
vmap(
39 const std::vector<array>& inputs,
40 const std::vector<int>& axes)
override;
41 std::vector<array>
jvp(
42 const std::vector<array>& primals,
43 const std::vector<array>& tangents,
44 const std::vector<int>& argnums)
override;
45 std::vector<array>
vjp(
46 const std::vector<array>& primals,
47 const std::vector<array>& cotangents,
48 const std::vector<int>& argnums,
49 const std::vector<array>& outputs)
override;
51 void print(std::ostream& os)
override {
52 switch (reduce_type_) {
82 void eval_cpu(
const std::vector<array>& inputs, std::vector<array>& outputs)
84 std::pair<std::vector<array>, std::vector<int>>
vmap(
85 const std::vector<array>& inputs,
86 const std::vector<int>& axes)
override;
87 std::vector<array>
jvp(
88 const std::vector<array>& primals,
89 const std::vector<array>& tangents,
90 const std::vector<int>& argnums)
override;
91 std::vector<array>
vjp(
92 const std::vector<array>& primals,
93 const std::vector<array>& cotangents,
94 const std::vector<int>& argnums,
95 const std::vector<array>& outputs)
override;
Definition primitives.h:48
Definition primitives.h:78
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
AllGather(Group group)
Definition primitives.h:80
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
Definition primitives.h:29
std::pair< std::vector< array >, std::vector< int > > vmap(const std::vector< array > &inputs, const std::vector< int > &axes) override
The primitive must know how to vectorize itself across the given axes.
void print(std::ostream &os) override
Print the primitive.
Definition primitives.h:51
AllReduce(Group group, ReduceType reduce_type)
Definition primitives.h:33
ReduceType
Definition primitives.h:31
@ Sum
Definition primitives.h:31
@ Min
Definition primitives.h:31
@ Or
Definition primitives.h:31
@ And
Definition primitives.h:31
@ Max
Definition primitives.h:31
@ Prod
Definition primitives.h:31
std::vector< array > vjp(const std::vector< array > &primals, const std::vector< array > &cotangents, const std::vector< int > &argnums, const std::vector< array > &outputs) override
The vector-Jacobian product.
void eval_cpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
A primitive must know how to evaluate itself on the CPU/GPU for the given inputs and populate the output arrays.
std::vector< array > jvp(const std::vector< array > &primals, const std::vector< array > &tangents, const std::vector< int > &argnums) override
The Jacobian-vector product.
Definition primitives.h:10
void eval_gpu(const std::vector< array > &inputs, std::vector< array > &outputs) override
Definition primitives.h:15
const Group & group() const
Definition primitives.h:21
DistPrimitive(Group group)
Definition primitives.h:12
Definition distributed.h:9
A distributed::Group represents a group of independent mlx processes that can communicate.
Definition distributed.h:19