MLX
 
Loading...
Searching...
No Matches
reduce.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
6
7namespace mlx::core {
8
10 // The input is contiguous and every axis is reduced: read everything and produce 1 output.
12
13 // The input is contiguous and the last axis is reduced
14 // N1xR1xN2xR2x...xNnxRn
16
17 // The input is contiguous and the last axis is not reduced
18 // R1xN1xR2xN2x...xRnxNn
20
21 // The input is not contiguous, but the last axis is contiguous and is
22 // reduced. We need to figure out the offsets, but we can call the contiguous
23 // reduce after that.
24 // N3xR1xN1xR4x...xRn
26
27 // The input is not contiguous, but the last reduction axis and the last axis
28 // are contiguous. We need to figure out the offset, but we can call the
29 // strided reduce after that.
31
32 // The input is not contiguous after the reduction axis and it may contain
33 // 0-stride axes or transpositions. We could copy the strides and produce a
34 // transposed outcome or we can read the input out of order and write the
35 // output in order.
37};
38
43
// Constructs a plan from a reduction type together with a shape and strides.
// `shape_` and `strides_` are taken by value and moved into the `shape` and
// `strides` members, so callers may pass temporaries without an extra copy.
44 ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
45 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
47};
48
// Returns the ReductionPlan for reducing array `x` over the axes in `axes`.
// NOTE(review): presumably selects a ReductionOpType from the layout of `x`;
// the definition is not in this header -- confirm against the implementation.
49ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
50
// Returns a (Shape, Strides) pair computed from array `x` and `axes`.
// NOTE(review): the name suggests these are the shape and strides of `x` with
// the reduction axes removed -- confirm against the definition, which is not
// in this header.
51std::pair<Shape, Strides> shapes_without_reduction_axes(
52 const array& x,
53 const std::vector<int>& axes);
54
55} // namespace mlx::core
Definition array.h:24
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Definition allocator.h:7
std::pair< Shape, Strides > shapes_without_reduction_axes(const array &x, const std::vector< int > &axes)
ReductionOpType
Definition reduce.h:9
@ GeneralReduce
Definition reduce.h:36
@ GeneralContiguousReduce
Definition reduce.h:25
@ ContiguousStridedReduce
Definition reduce.h:19
@ ContiguousReduce
Definition reduce.h:15
@ GeneralStridedReduce
Definition reduce.h:30
@ ContiguousAllReduce
Definition reduce.h:11
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
ReductionPlan get_reduction_plan(const array &x, const std::vector< int > &axes)
Definition reduce.h:39
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
Definition reduce.h:44
Shape shape
Definition reduce.h:41
ReductionOpType type
Definition reduce.h:40
Strides strides
Definition reduce.h:42
ReductionPlan(ReductionOpType type_)
Definition reduce.h:46