MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/stream.h"
10#include "mlx/utils.h"
11
12namespace mlx::core {
13
23 double start,
24 double stop,
25 double step,
26 Dtype dtype,
27 StreamOrDevice s = {});
28array arange(double start, double stop, double step, StreamOrDevice s = {});
29array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
30array arange(double start, double stop, StreamOrDevice s = {});
31array arange(double stop, Dtype dtype, StreamOrDevice s = {});
32array arange(double stop, StreamOrDevice s = {});
33
34array arange(int start, int stop, int step, StreamOrDevice s = {});
35array arange(int start, int stop, StreamOrDevice s = {});
36array arange(int stop, StreamOrDevice s = {});
37
40 double start,
41 double stop,
42 int num = 50,
43 Dtype dtype = float32,
44 StreamOrDevice s = {});
45
48
51 array a,
52 std::vector<int> shape,
53 std::vector<size_t> strides,
54 size_t offset,
55 StreamOrDevice s = {});
56
59
62 std::vector<int> shape,
63 array vals,
64 Dtype dtype,
65 StreamOrDevice s = {});
66array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
67template <typename T>
68array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
69 return full(std::move(shape), array(val, dtype), to_stream(s));
70}
71template <typename T>
72array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
73 return full(std::move(shape), array(val), to_stream(s));
74}
75
77array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
78inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
79 return zeros(shape, float32, s);
80}
82
84array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
85inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
86 return ones(shape, float32, s);
87}
89
92array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
93inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
94 return eye(n, n, 0, dtype, s);
95}
96inline array eye(int n, int m, StreamOrDevice s = {}) {
97 return eye(n, m, 0, float32, s);
98}
99inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
100 return eye(n, m, k, float32, s);
101}
102inline array eye(int n, StreamOrDevice s = {}) {
103 return eye(n, n, 0, float32, s);
104}
105
108array identity(int n, Dtype dtype, StreamOrDevice s = {});
109inline array identity(int n, StreamOrDevice s = {}) {
110 return identity(n, float32, s);
111}
112
113array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
114inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
115 return tri(n, n, 0, type, s);
116}
117
118array tril(array x, int k = 0, StreamOrDevice s = {});
119array triu(array x, int k = 0, StreamOrDevice s = {});
120
122array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
123
126 const array& a,
127 int start_axis,
128 int end_axis = -1,
129 StreamOrDevice s = {});
130
133
136 const array& a,
137 std::optional<float> scale = std::nullopt,
138 StreamOrDevice s = {});
139
142 const array& a,
143 const std::vector<int>& axes,
144 StreamOrDevice s = {});
145
147array squeeze(const array& a, int axis, StreamOrDevice s = {});
148
151
154 const array& a,
155 const std::vector<int>& axes,
156 StreamOrDevice s = {});
157
159array expand_dims(const array& a, int axis, StreamOrDevice s = {});
160
163 const array& a,
164 std::vector<int> start,
165 std::vector<int> stop,
166 std::vector<int> strides,
167 StreamOrDevice s = {});
168
171 const array& a,
172 std::vector<int> start,
173 std::vector<int> stop,
174 StreamOrDevice s = {});
175
178 const array& src,
179 const array& update,
180 std::vector<int> start,
181 std::vector<int> stop,
182 std::vector<int> strides,
183 StreamOrDevice s = {});
184
187 const array& src,
188 const array& update,
189 std::vector<int> start,
190 std::vector<int> stop,
191 StreamOrDevice s = {});
192
194std::vector<array>
195split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
196std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
197std::vector<array> split(
198 const array& a,
199 const std::vector<int>& indices,
200 int axis,
201 StreamOrDevice s = {});
202std::vector<array>
203split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
204
206std::vector<array> meshgrid(
207 const std::vector<array>& arrays,
208 bool sparse = false,
209 std::string indexing = "xy",
210 StreamOrDevice s = {});
211
216 const array& a,
217 const std::optional<array>& a_min = std::nullopt,
218 const std::optional<array>& a_max = std::nullopt,
219 StreamOrDevice s = {});
220
223 const std::vector<array>& arrays,
224 int axis,
225 StreamOrDevice s = {});
226array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
227
229array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
230array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
231
233array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
234array repeat(const array& arr, int repeats, StreamOrDevice s = {});
235
236array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
237
239array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
241 const array& a,
242 std::initializer_list<int> axes,
243 StreamOrDevice s = {}) {
244 return transpose(a, std::vector<int>(axes), s);
245}
246
248array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
249
252 const array& a,
253 int source,
254 int destination,
255 StreamOrDevice s = {});
256
259 const array& a,
260 const std::vector<int>& axes,
261 const std::vector<int>& low_pad_size,
262 const std::vector<int>& high_pad_size,
263 const array& pad_value = array(0),
264 const std::string mode = "constant",
265 StreamOrDevice s = {});
266
269 const array& a,
270 const std::vector<std::pair<int, int>>& pad_width,
271 const array& pad_value = array(0),
272 const std::string mode = "constant",
273 StreamOrDevice s = {});
275 const array& a,
276 const std::pair<int, int>& pad_width,
277 const array& pad_value = array(0),
278 const std::string mode = "constant",
279 StreamOrDevice s = {});
281 const array& a,
282 int pad_width,
283 const array& pad_value = array(0),
284 const std::string mode = "constant",
285 StreamOrDevice s = {});
286
289
292 const array& a,
293 const std::vector<int>& shape,
294 StreamOrDevice s = {});
295
297std::vector<array> broadcast_arrays(
298 const std::vector<array>& inputs,
299 StreamOrDevice s = {});
300
302array equal(const array& a, const array& b, StreamOrDevice s = {});
303inline array operator==(const array& a, const array& b) {
304 return equal(a, b);
305}
306template <typename T>
307array operator==(T a, const array& b) {
308 return equal(array(a), b);
309}
310template <typename T>
311array operator==(const array& a, T b) {
312 return equal(a, array(b));
313}
314
316array not_equal(const array& a, const array& b, StreamOrDevice s = {});
317inline array operator!=(const array& a, const array& b) {
318 return not_equal(a, b);
319}
320template <typename T>
321array operator!=(T a, const array& b) {
322 return not_equal(array(a), b);
323}
324template <typename T>
325array operator!=(const array& a, T b) {
326 return not_equal(a, array(b));
327}
328
330array greater(const array& a, const array& b, StreamOrDevice s = {});
331inline array operator>(const array& a, const array& b) {
332 return greater(a, b);
333}
334template <typename T>
335array operator>(T a, const array& b) {
336 return greater(array(a), b);
337}
338template <typename T>
339array operator>(const array& a, T b) {
340 return greater(a, array(b));
341}
342
344array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
345inline array operator>=(const array& a, const array& b) {
346 return greater_equal(a, b);
347}
348template <typename T>
349array operator>=(T a, const array& b) {
350 return greater_equal(array(a), b);
351}
352template <typename T>
353array operator>=(const array& a, T b) {
354 return greater_equal(a, array(b));
355}
356
358array less(const array& a, const array& b, StreamOrDevice s = {});
359inline array operator<(const array& a, const array& b) {
360 return less(a, b);
361}
362template <typename T>
363array operator<(T a, const array& b) {
364 return less(array(a), b);
365}
366template <typename T>
367array operator<(const array& a, T b) {
368 return less(a, array(b));
369}
370
372array less_equal(const array& a, const array& b, StreamOrDevice s = {});
373inline array operator<=(const array& a, const array& b) {
374 return less_equal(a, b);
375}
376template <typename T>
377array operator<=(T a, const array& b) {
378 return less_equal(array(a), b);
379}
380template <typename T>
381array operator<=(const array& a, T b) {
382 return less_equal(a, array(b));
383}
384
387 const array& a,
388 const array& b,
389 bool equal_nan,
390 StreamOrDevice s = {});
391inline array
392array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
393 return array_equal(a, b, false, s);
394}
395
396array isnan(const array& a, StreamOrDevice s = {});
397
398array isinf(const array& a, StreamOrDevice s = {});
399
401
403
405
408 const array& condition,
409 const array& x,
410 const array& y,
411 StreamOrDevice s = {});
412
415 const array& a,
416 float nan = 0.0f,
417 const std::optional<float> posinf = std::nullopt,
418 const std::optional<float> neginf = std::nullopt,
419 StreamOrDevice s = {});
420
422array all(const array& a, bool keepdims, StreamOrDevice s = {});
423inline array all(const array& a, StreamOrDevice s = {}) {
424 return all(a, false, to_stream(s));
425}
426
429 const array& a,
430 const array& b,
431 double rtol = 1e-5,
432 double atol = 1e-8,
433 bool equal_nan = false,
434 StreamOrDevice s = {});
435
439 const array& a,
440 const array& b,
441 double rtol = 1e-5,
442 double atol = 1e-8,
443 bool equal_nan = false,
444 StreamOrDevice s = {});
445
451 const array& a,
452 const std::vector<int>& axes,
453 bool keepdims = false,
454 StreamOrDevice s = {});
455
461 const array& a,
462 int axis,
463 bool keepdims = false,
464 StreamOrDevice s = {});
465
467array any(const array& a, bool keepdims, StreamOrDevice s = {});
468inline array any(const array& a, StreamOrDevice s = {}) {
469 return any(a, false, to_stream(s));
470}
471
477 const array& a,
478 const std::vector<int>& axes,
479 bool keepdims = false,
480 StreamOrDevice s = {});
481
487 const array& a,
488 int axis,
489 bool keepdims = false,
490 StreamOrDevice s = {});
491
493array sum(const array& a, bool keepdims, StreamOrDevice s = {});
494inline array sum(const array& a, StreamOrDevice s = {}) {
495 return sum(a, false, to_stream(s));
496}
497
500 const array& a,
501 const std::vector<int>& axes,
502 bool keepdims = false,
503 StreamOrDevice s = {});
504
507 const array& a,
508 int axis,
509 bool keepdims = false,
510 StreamOrDevice s = {});
511
513array mean(const array& a, bool keepdims, StreamOrDevice s = {});
514inline array mean(const array& a, StreamOrDevice s = {}) {
515 return mean(a, false, to_stream(s));
516}
517
520 const array& a,
521 const std::vector<int>& axes,
522 bool keepdims = false,
523 StreamOrDevice s = {});
524
527 const array& a,
528 int axis,
529 bool keepdims = false,
530 StreamOrDevice s = {});
531
533array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
534inline array var(const array& a, StreamOrDevice s = {}) {
535 return var(a, false, 0, to_stream(s));
536}
537
541 const array& a,
542 const std::vector<int>& axes,
543 bool keepdims = false,
544 int ddof = 0,
545 StreamOrDevice s = {});
546
550 const array& a,
551 int axis,
552 bool keepdims = false,
553 int ddof = 0,
554 StreamOrDevice s = {});
555
557array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
558inline array std(const array& a, StreamOrDevice s = {}) {
559 return std(a, false, 0, to_stream(s));
560}
561
565 const array& a,
566 const std::vector<int>& axes,
567 bool keepdims = false,
568 int ddof = 0,
569 StreamOrDevice s = {});
570
574 const array& a,
575 int axis,
576 bool keepdims = false,
577 int ddof = 0,
578 StreamOrDevice s = {});
579
581array prod(const array& a, bool keepdims, StreamOrDevice s = {});
582inline array prod(const array& a, StreamOrDevice s = {}) {
583 return prod(a, false, to_stream(s));
584}
585
588 const array& a,
589 const std::vector<int>& axes,
590 bool keepdims = false,
591 StreamOrDevice s = {});
592
595 const array& a,
596 int axis,
597 bool keepdims = false,
598 StreamOrDevice s = {});
599
601array max(const array& a, bool keepdims, StreamOrDevice s = {});
602inline array max(const array& a, StreamOrDevice s = {}) {
603 return max(a, false, to_stream(s));
604}
605
608 const array& a,
609 const std::vector<int>& axes,
610 bool keepdims = false,
611 StreamOrDevice s = {});
612
615 const array& a,
616 int axis,
617 bool keepdims = false,
618 StreamOrDevice s = {});
619
621array min(const array& a, bool keepdims, StreamOrDevice s = {});
622inline array min(const array& a, StreamOrDevice s = {}) {
623 return min(a, false, to_stream(s));
624}
625
628 const array& a,
629 const std::vector<int>& axes,
630 bool keepdims = false,
631 StreamOrDevice s = {});
632
635 const array& a,
636 int axis,
637 bool keepdims = false,
638 StreamOrDevice s = {});
639
641array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
642inline array argmin(const array& a, StreamOrDevice s = {}) {
643 return argmin(a, false, s);
644}
645
648 const array& a,
649 int axis,
650 bool keepdims = false,
651 StreamOrDevice s = {});
652
654array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
655inline array argmax(const array& a, StreamOrDevice s = {}) {
656 return argmax(a, false, s);
657}
658
661 const array& a,
662 int axis,
663 bool keepdims = false,
664 StreamOrDevice s = {});
665
667array sort(const array& a, StreamOrDevice s = {});
668
670array sort(const array& a, int axis, StreamOrDevice s = {});
671
674
676array argsort(const array& a, int axis, StreamOrDevice s = {});
677
682array partition(const array& a, int kth, StreamOrDevice s = {});
683
688array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
689
694array argpartition(const array& a, int kth, StreamOrDevice s = {});
695
700array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
701
703array topk(const array& a, int k, StreamOrDevice s = {});
704
706array topk(const array& a, int k, int axis, StreamOrDevice s = {});
707
709array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
710inline array logsumexp(const array& a, StreamOrDevice s = {}) {
711 return logsumexp(a, false, to_stream(s));
712}
713
716 const array& a,
717 const std::vector<int>& axes,
718 bool keepdims = false,
719 StreamOrDevice s = {});
720
723 const array& a,
724 int axis,
725 bool keepdims = false,
726 StreamOrDevice s = {});
727
729array abs(const array& a, StreamOrDevice s = {});
730
734
736array sign(const array& a, StreamOrDevice s = {});
737
740
742array logical_and(const array& a, const array& b, StreamOrDevice s = {});
743array operator&&(const array& a, const array& b);
744
746array logical_or(const array& a, const array& b, StreamOrDevice s = {});
747array operator||(const array& a, const array& b);
748
751
753array add(const array& a, const array& b, StreamOrDevice s = {});
754array operator+(const array& a, const array& b);
755template <typename T>
756array operator+(T a, const array& b) {
757 return add(array(a), b);
758}
759template <typename T>
760array operator+(const array& a, T b) {
761 return add(a, array(b));
762}
763
765array subtract(const array& a, const array& b, StreamOrDevice s = {});
766array operator-(const array& a, const array& b);
767template <typename T>
768array operator-(T a, const array& b) {
769 return subtract(array(a), b);
770}
771template <typename T>
772array operator-(const array& a, T b) {
773 return subtract(a, array(b));
774}
775
777array multiply(const array& a, const array& b, StreamOrDevice s = {});
778array operator*(const array& a, const array& b);
779template <typename T>
780array operator*(T a, const array& b) {
781 return multiply(array(a), b);
782}
783template <typename T>
784array operator*(const array& a, T b) {
785 return multiply(a, array(b));
786}
787
789array divide(const array& a, const array& b, StreamOrDevice s = {});
790array operator/(const array& a, const array& b);
791array operator/(double a, const array& b);
792array operator/(const array& a, double b);
793
795std::vector<array>
796divmod(const array& a, const array& b, StreamOrDevice s = {});
797
799array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
800
802array remainder(const array& a, const array& b, StreamOrDevice s = {});
803array operator%(const array& a, const array& b);
804template <typename T>
805array operator%(T a, const array& b) {
806 return remainder(array(a), b);
807}
808template <typename T>
809array operator%(const array& a, T b) {
810 return remainder(a, array(b));
811}
812
814array maximum(const array& a, const array& b, StreamOrDevice s = {});
815
817array minimum(const array& a, const array& b, StreamOrDevice s = {});
818
820array floor(const array& a, StreamOrDevice s = {});
821
823array ceil(const array& a, StreamOrDevice s = {});
824
827
829array exp(const array& a, StreamOrDevice s = {});
830
832array sin(const array& a, StreamOrDevice s = {});
833
835array cos(const array& a, StreamOrDevice s = {});
836
838array tan(const array& a, StreamOrDevice s = {});
839
842
845
848
850array arctan2(const array& a, const array& b, StreamOrDevice s = {});
851
853array sinh(const array& a, StreamOrDevice s = {});
854
856array cosh(const array& a, StreamOrDevice s = {});
857
859array tanh(const array& a, StreamOrDevice s = {});
860
863
866
869
872
875
877array log(const array& a, StreamOrDevice s = {});
878
880array log2(const array& a, StreamOrDevice s = {});
881
883array log10(const array& a, StreamOrDevice s = {});
884
886array log1p(const array& a, StreamOrDevice s = {});
887
889array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
890
893
895array erf(const array& a, StreamOrDevice s = {});
896
899
901array expm1(const array& a, StreamOrDevice s = {});
902
905
907array round(const array& a, int decimals, StreamOrDevice s = {});
908inline array round(const array& a, StreamOrDevice s = {}) {
909 return round(a, 0, s);
910}
911
913array matmul(const array& a, const array& b, StreamOrDevice s = {});
914
917 const array& a,
918 const std::vector<array>& indices,
919 const std::vector<int>& axes,
920 const std::vector<int>& slice_sizes,
921 StreamOrDevice s = {});
923 const array& a,
924 const array& indices,
925 int axis,
926 const std::vector<int>& slice_sizes,
927 StreamOrDevice s = {}) {
928 return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
929}
930
933 const array& a,
934 const array& indices,
935 int axis,
936 StreamOrDevice s = {});
937array take(const array& a, int index, int axis, StreamOrDevice s = {});
938
940array take(const array& a, const array& indices, StreamOrDevice s = {});
941array take(const array& a, int index, StreamOrDevice s = {});
942
945 const array& a,
946 const array& indices,
947 int axis,
948 StreamOrDevice s = {});
949
952 const array& a,
953 const array& indices,
954 const array& values,
955 int axis,
956 StreamOrDevice s = {});
957
1057 const array& a,
1058 const std::vector<array>& indices,
1059 const array& updates,
1060 const std::vector<int>& axes,
1061 StreamOrDevice s = {});
1063 const array& a,
1064 const array& indices,
1065 const array& updates,
1066 int axis,
1067 StreamOrDevice s = {}) {
1068 return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
1069}
1070
1073 const array& a,
1074 const std::vector<array>& indices,
1075 const array& updates,
1076 const std::vector<int>& axes,
1077 StreamOrDevice s = {});
1079 const array& a,
1080 const array& indices,
1081 const array& updates,
1082 int axis,
1083 StreamOrDevice s = {}) {
1084 return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
1085}
1086
1089 const array& a,
1090 const std::vector<array>& indices,
1091 const array& updates,
1092 const std::vector<int>& axes,
1093 StreamOrDevice s = {});
1095 const array& a,
1096 const array& indices,
1097 const array& updates,
1098 int axis,
1099 StreamOrDevice s = {}) {
1100 return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
1101}
1102
1105 const array& a,
1106 const std::vector<array>& indices,
1107 const array& updates,
1108 const std::vector<int>& axes,
1109 StreamOrDevice s = {});
1111 const array& a,
1112 const array& indices,
1113 const array& updates,
1114 int axis,
1115 StreamOrDevice s = {}) {
1116 return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
1117}
1120 const array& a,
1121 const std::vector<array>& indices,
1122 const array& updates,
1123 const std::vector<int>& axes,
1124 StreamOrDevice s = {});
1126 const array& a,
1127 const array& indices,
1128 const array& updates,
1129 int axis,
1130 StreamOrDevice s = {}) {
1131 return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
1132}
1133
1135array sqrt(const array& a, StreamOrDevice s = {});
1136
1139
1142 const array& a,
1143 const std::vector<int>& axes,
1144 bool precise = false,
1145 StreamOrDevice s = {});
1146
1148array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
1149
1151inline array
1152softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
1153 return softmax(a, std::vector<int>{axis}, precise, s);
1154}
1155
1157array power(const array& a, const array& b, StreamOrDevice s = {});
1158
1161 const array& a,
1162 int axis,
1163 bool reverse = false,
1164 bool inclusive = true,
1165 StreamOrDevice s = {});
1166
1169 const array& a,
1170 int axis,
1171 bool reverse = false,
1172 bool inclusive = true,
1173 StreamOrDevice s = {});
1174
1177 const array& a,
1178 int axis,
1179 bool reverse = false,
1180 bool inclusive = true,
1181 StreamOrDevice s = {});
1182
1185 const array& a,
1186 int axis,
1187 bool reverse = false,
1188 bool inclusive = true,
1189 StreamOrDevice s = {});
1190
1193 array input,
1194 array weight,
1195 std::vector<int> stride = {},
1196 std::vector<int> padding_lo = {},
1197 std::vector<int> padding_hi = {},
1198 std::vector<int> kernel_dilation = {},
1199 std::vector<int> input_dilation = {},
1200 int groups = 1,
1201 bool flip = false,
1202 StreamOrDevice s = {});
1203
1206 const array& input,
1207 const array& weight,
1208 std::vector<int> stride = {},
1209 std::vector<int> padding = {},
1210 std::vector<int> kernel_dilation = {},
1211 std::vector<int> input_dilation = {},
1212 int groups = 1,
1213 bool flip = false,
1214 StreamOrDevice s = {}) {
1215 return conv_general(
1216 /* const array& input = */ input,
1217 /* const array& weight = */ weight,
1218 /* std::vector<int> stride = */ stride,
1219 /* std::vector<int> padding_lo = */ padding,
1220 /* std::vector<int> padding_hi = */ padding,
1221 /* std::vector<int> kernel_dilation = */ kernel_dilation,
1222 /* std::vector<int> input_dilation = */ input_dilation,
1223 /* int groups = */ groups,
1224 /* bool flip = */ flip,
1225 /* StreamOrDevice s = */ s);
1226}
1227
1230 const array& input,
1231 const array& weight,
1232 int stride = 1,
1233 int padding = 0,
1234 int dilation = 1,
1235 int groups = 1,
1236 StreamOrDevice s = {});
1237
1240 const array& input,
1241 const array& weight,
1242 const std::pair<int, int>& stride = {1, 1},
1243 const std::pair<int, int>& padding = {0, 0},
1244 const std::pair<int, int>& dilation = {1, 1},
1245 int groups = 1,
1246 StreamOrDevice s = {});
1247
1250 const array& input,
1251 const array& weight,
1252 const std::tuple<int, int, int>& stride = {1, 1, 1},
1253 const std::tuple<int, int, int>& padding = {0, 0, 0},
1254 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1255 int groups = 1,
1256 StreamOrDevice s = {});
1257
1260 const array& input,
1261 const array& weight,
1262 int stride = 1,
1263 int padding = 0,
1264 int dilation = 1,
1265 int groups = 1,
1266 StreamOrDevice s = {});
1267
1270 const array& input,
1271 const array& weight,
1272 const std::pair<int, int>& stride = {1, 1},
1273 const std::pair<int, int>& padding = {0, 0},
1274 const std::pair<int, int>& dilation = {1, 1},
1275 int groups = 1,
1276 StreamOrDevice s = {});
1277
1280 const array& input,
1281 const array& weight,
1282 const std::tuple<int, int, int>& stride = {1, 1, 1},
1283 const std::tuple<int, int, int>& padding = {0, 0, 0},
1284 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1285 int groups = 1,
1286 StreamOrDevice s = {});
1287
1290 array x,
1291 array w,
1292 array scales,
1293 array biases,
1294 bool transpose = true,
1295 int group_size = 64,
1296 int bits = 4,
1297 StreamOrDevice s = {});
1298
1300std::tuple<array, array, array> quantize(
1301 const array& w,
1302 int group_size = 64,
1303 int bits = 4,
1304 StreamOrDevice s = {});
1305
1308 const array& w,
1309 const array& scales,
1310 const array& biases,
1311 int group_size = 64,
1312 int bits = 4,
1313 StreamOrDevice s = {});
1314
1317 const array& x,
1318 const array& w,
1319 const array& scales,
1320 const array& biases,
1321 std::optional<array> lhs_indices = std::nullopt,
1322 std::optional<array> rhs_indices = std::nullopt,
1323 bool transpose = true,
1324 int group_size = 64,
1325 int bits = 4,
1326 StreamOrDevice s = {});
1327
1330 const array& a,
1331 const array& b,
1332 const int axis = 2,
1333 StreamOrDevice s = {});
1334
1336 const array& a,
1337 const array& b,
1338 const std::vector<int>& axes_a,
1339 const std::vector<int>& axes_b,
1340 StreamOrDevice s = {});
1341
1343array outer(const array& a, const array& b, StreamOrDevice s = {});
1344
1346array inner(const array& a, const array& b, StreamOrDevice s = {});
1347
1350 array c,
1351 array a,
1352 array b,
1353 const float& alpha = 1.f,
1354 const float& beta = 1.f,
1355 StreamOrDevice s = {});
1356
1359 array a,
1360 array b,
1361 int block_size,
1362 std::optional<array> mask_out = std::nullopt,
1363 std::optional<array> mask_lhs = std::nullopt,
1364 std::optional<array> mask_rhs = std::nullopt,
1365 StreamOrDevice s = {});
1366
1369 array a,
1370 array b,
1371 std::optional<array> lhs_indices = std::nullopt,
1372 std::optional<array> rhs_indices = std::nullopt,
1373 StreamOrDevice s = {});
1374
1377 const array& a,
1378 int offset = 0,
1379 int axis1 = 0,
1380 int axis2 = 1,
1381 StreamOrDevice s = {});
1382
1384array diag(const array& a, int k = 0, StreamOrDevice s = {});
1385
1388 const array& a,
1389 int offset,
1390 int axis1,
1391 int axis2,
1392 Dtype dtype,
1393 StreamOrDevice s = {});
1395 const array& a,
1396 int offset,
1397 int axis1,
1398 int axis2,
1399 StreamOrDevice s = {});
1401
1407std::vector<array> depends(
1408 const std::vector<array>& inputs,
1409 const std::vector<array>& dependencies);
1410
1413std::vector<array> atleast_1d(
1414 const std::vector<array>& a,
1415 StreamOrDevice s = {});
1417std::vector<array> atleast_2d(
1418 const std::vector<array>& a,
1419 StreamOrDevice s = {});
1421std::vector<array> atleast_3d(
1422 const std::vector<array>& a,
1423 StreamOrDevice s = {});
1424
1430 const array& a,
1431 std::vector<int> axes,
1432 bool inverted,
1433 Dtype dtype = int32,
1434 StreamOrDevice s = {});
1435
1437
1439array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
1440array operator&(const array& a, const array& b);
1441
1443array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
1444array operator|(const array& a, const array& b);
1445
1447array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
1448array operator^(const array& a, const array& b);
1449
1451array left_shift(const array& a, const array& b, StreamOrDevice s = {});
1452array operator<<(const array& a, const array& b);
1453
1455array right_shift(const array& a, const array& b, StreamOrDevice s = {});
1456array operator>>(const array& a, const array& b);
1457
1458array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
1459
1461array roll(const array& a, int shift, StreamOrDevice s = {});
1463 const array& a,
1464 const std::vector<int>& shift,
1465 StreamOrDevice s = {});
1466array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
1468 const array& a,
1469 int shift,
1470 const std::vector<int>& axes,
1471 StreamOrDevice s = {});
1473 const array& a,
1474 const std::vector<int>& shift,
1475 int axis,
1476 StreamOrDevice s = {});
1478 const array& a,
1479 const std::vector<int>& shift,
1480 const std::vector<int>& axes,
1481 StreamOrDevice s = {});
1482
1483/* The real part of a complex array. */
1484array real(const array& a, StreamOrDevice s = {});
1485
1486/* The imaginary part of a complex array. */
1487array imag(const array& a, StreamOrDevice s = {});
1488
1491} // namespace mlx::core
Definition array.h:20
array scatter_max(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and max updates to given linear indices.
array floor_divide(const array &a, const array &b, StreamOrDevice s={})
Compute integer division.
array radians(const array &a, StreamOrDevice s={})
Convert the elements of an array from degrees to radians.
array arccos(const array &a, StreamOrDevice s={})
Arc cosine of the elements of an array.
array scatter_min(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and min updates to given linear indices.
array less_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a <= b) element-wise.
array cumprod(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative product of an array.
array astype(array a, Dtype dtype, StreamOrDevice s={})
Convert an array to the given data type.
array rsqrt(const array &a, StreamOrDevice s={})
Reciprocal square root of the elements of an array.
array diag(const array &a, int k=0, StreamOrDevice s={})
Extract diagonal from a 2d array or create a diagonal matrix.
array square(const array &a, StreamOrDevice s={})
Square the elements of an array.
array ceil(const array &a, StreamOrDevice s={})
Ceil the elements of an array.
array log2(const array &a, StreamOrDevice s={})
Log base 2 of the elements of an array.
array clip(const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
Clip (limit) the values in an array.
array isnan(const array &a, StreamOrDevice s={})
array isneginf(const array &a, StreamOrDevice s={})
array subtract(const array &a, const array &b, StreamOrDevice s={})
Subtract two arrays.
array cummin(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative min of an array.
array log10(const array &a, StreamOrDevice s={})
Log base 10 of the elements of an array.
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array sign(const array &a, StreamOrDevice s={})
The sign of the elements in an array.
array cosh(const array &a, StreamOrDevice s={})
Hyperbolic cosine of the elements of an array.
array conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
General convolution with a filter.
array logical_or(const array &a, const array &b, StreamOrDevice s={})
Logical or of two arrays.
array moveaxis(const array &a, int source, int destination, StreamOrDevice s={})
Move an axis of an array.
array operator*(const array &a, const array &b)
array operator+(const array &a, const array &b)
array operator||(const array &a, const array &b)
array not_equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a != b) element-wise.
array erf(const array &a, StreamOrDevice s={})
Computes the error function of the elements of an array.
array sqrt(const array &a, StreamOrDevice s={})
Square root the elements of an array.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array add(const array &a, const array &b, StreamOrDevice s={})
Add two arrays.
array round(const array &a, int decimals, StreamOrDevice s={})
Round the elements of an array to the given number of decimals.
array conv1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D convolution with a filter
array bitwise_xor(const array &a, const array &b, StreamOrDevice s={})
Bitwise exclusive or.
array equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a == b) element-wise.
array zeros(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with zeros.
array view(const array &a, const Dtype &dtype, StreamOrDevice s={})
array gather_qmm(const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Compute matrix products with matrix-level gather.
array stop_gradient(const array &a, StreamOrDevice s={})
Stop the flow of gradients.
array scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and prod updates to given indices.
array slice_update(const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Update a slice from the source array.
array cos(const array &a, StreamOrDevice s={})
Cosine of the elements of an array.
array operator>=(const array &a, const array &b)
Definition ops.h:345
array degrees(const array &a, StreamOrDevice s={})
Convert the elements of an array from Radians to Degrees.
array all(const array &a, bool keepdims, StreamOrDevice s={})
True if all elements in the array are true (or non-zero).
array tan(const array &a, StreamOrDevice s={})
Tangent of the elements of an array.
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
array operator>>(const array &a, const array &b)
array minimum(const array &a, const array &b, StreamOrDevice s={})
Element-wise minimum between two arrays.
array prod(const array &a, bool keepdims, StreamOrDevice s={})
The product of all elements of the array.
array atleast_3d(const array &a, StreamOrDevice s={})
array operator<=(const array &a, const array &b)
Definition ops.h:373
array reciprocal(const array &a, StreamOrDevice s={})
The reciprocal (1/x) of the elements in an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array flatten(const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
Flatten the dimensions in the range [start_axis, end_axis].
array isclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
array operator|(const array &a, const array &b)
array topk(const array &a, int k, StreamOrDevice s={})
Returns the top k elements of the flattened array.
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array ones(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with ones.
array abs(const array &a, StreamOrDevice s={})
Absolute value of elements in an array.
std::vector< array > meshgrid(const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
A vector of coordinate arrays from coordinate vectors.
array conjugate(const array &a, StreamOrDevice s={})
array tanh(const array &a, StreamOrDevice s={})
Hyperbolic Tangent of the elements of an array.
array inner(const array &a, const array &b, StreamOrDevice s={})
Compute the inner product of two vectors.
array block_masked_mm(array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
Compute matrix product with block masking.
array arctan2(const array &a, const array &b, StreamOrDevice s={})
Inverse tangent of the ratio of two arrays.
array number_of_elements(const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
Extract the number of elements along some axes as a scalar array.
array conv3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D convolution with a filter
array log(const array &a, StreamOrDevice s={})
Natural logarithm of the elements of an array.
array sigmoid(const array &a, StreamOrDevice s={})
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
array squeeze(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Remove singleton dimensions at the given axes.
array greater_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a >= b) element-wise.
array expand_dims(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Add a singleton dimension at the given axes.
array isfinite(const array &a, StreamOrDevice s={})
array conv2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D convolution with a filter
array operator>(const array &a, const array &b)
Definition ops.h:331
array bitwise_and(const array &a, const array &b, StreamOrDevice s={})
Bitwise and.
std::vector< array > split(const array &a, int num_splits, int axis, StreamOrDevice s={})
Split an array into sub-arrays along a given axis.
array matmul(const array &a, const array &b, StreamOrDevice s={})
Matrix-matrix multiplication.
array logical_and(const array &a, const array &b, StreamOrDevice s={})
Logical and of two arrays.
array erfinv(const array &a, StreamOrDevice s={})
Computes the inverse error function of the elements of an array.
array divide(const array &a, const array &b, StreamOrDevice s={})
Divide two arrays.
array power(const array &a, const array &b, StreamOrDevice s={})
Raise elements of a to the power of b element-wise.
array maximum(const array &a, const array &b, StreamOrDevice s={})
Element-wise maximum between two arrays.
array reshape(const array &a, std::vector< int > shape, StreamOrDevice s={})
Reshape an array to the given shape.
array argmin(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the minimum value in the array.
array var(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the variance of the elements of an array.
array full(std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with the given value(s).
array softmax(const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
Softmax of an array.
array sort(const array &a, StreamOrDevice s={})
Returns a sorted copy of the flattened array.
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array imag(const array &a, StreamOrDevice s={})
array pad(const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
Pad an array with a constant value.
array addmm(array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
Compute D = beta * C + alpha * (A @ B)
array tril(array x, int k=0, StreamOrDevice s={})
array any(const array &a, bool keepdims, StreamOrDevice s={})
True if any elements in the array are true (or non-zero).
array outer(const array &a, const array &b, StreamOrDevice s={})
Compute the outer product of two vectors.
array hadamard_transform(const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
Multiply the array by the Hadamard matrix of corresponding size.
array arcsin(const array &a, StreamOrDevice s={})
Arc Sine of the elements of an array.
array left_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the left.
array where(const array &condition, const array &x, const array &y, StreamOrDevice s={})
Select from x or y depending on condition.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
array bitwise_or(const array &a, const array &b, StreamOrDevice s={})
Bitwise inclusive or.
array gather_mm(array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
Compute matrix product with matrix-level gather.
array floor(const array &a, StreamOrDevice s={})
Floor the elements of an array.
array conv_transpose3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D transposed convolution with a filter
array as_strided(array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
Create a view of an array with the given shape and strides.
array argsort(const array &a, StreamOrDevice s={})
Returns indices that sort the flattened array.
array put_along_axis(const array &a, const array &indices, const array &values, int axis, StreamOrDevice s={})
Put the values into the array at the given indices along the axis.
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
True if two arrays have the same shape and elements.
array isinf(const array &a, StreamOrDevice s={})
array less(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a < b) element-wise.
array diagonal(const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
Extract a diagonal or construct a diagonal array.
array ones_like(const array &a, StreamOrDevice s={})
array negative(const array &a, StreamOrDevice s={})
Negate an array.
array linspace(double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
A 1D array of num evenly spaced numbers in the range [start, stop]
array remainder(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise remainder of division.
array arctan(const array &a, StreamOrDevice s={})
Arc Tangent of the elements of an array.
array conv_transpose1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D transposed convolution with a filter
std::vector< array > divmod(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise quotient and remainder.
array triu(array x, int k=0, StreamOrDevice s={})
array arccosh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Cosine of the elements of an array.
array tile(const array &arr, std::vector< int > reps, StreamOrDevice s={})
array nan_to_num(const array &a, float nan=0.0f, const std::optional< float > posinf=std::nullopt, const std::optional< float > neginf=std::nullopt, StreamOrDevice s={})
Replace NaN and infinities with finite numbers.
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator%(const array &a, const array &b)
std::tuple< array, array, array > quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
Quantize a matrix along its last axis.
array arctanh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Tangent of the elements of an array.
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s={})
Repeat an array along an axis.
array gather(const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
Gather array entries given indices and slices.
std::vector< array > broadcast_arrays(const std::vector< array > &inputs, StreamOrDevice s={})
Broadcast a vector of arrays against one another.
array atleast_1d(const array &a, StreamOrDevice s={})
Convert an array to an array with at least one dimension.
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s={})
Swap two axes of an array.
array logical_not(const array &a, StreamOrDevice s={})
Logical not of an array.
array concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Concatenate arrays along a given axis.
array trace(const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
Return the sum along a specified diagonal in the given array.
array quantized_matmul(array x, array w, array scales, array biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Quantized matmul multiplies x with a quantized matrix w.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array partition(const array &a, int kth, StreamOrDevice s={})
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
array take(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array slices at the given indices of the specified axis.
array operator^(const array &a, const array &b)
array roll(const array &a, int shift, StreamOrDevice s={})
Roll elements along an axis and introduce them on the other side.
std::vector< array > depends(const std::vector< array > &inputs, const std::vector< array > &dependencies)
Implements the identity function but allows injecting dependencies to other arrays.
array arcsinh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Sine of the elements of an array.
array scatter_add(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and add updates to given indices.
array logsumexp(const array &a, bool keepdims, StreamOrDevice s={})
The logsumexp of all elements of the array.
array broadcast_to(const array &a, const std::vector< int > &shape, StreamOrDevice s={})
Broadcast an array to a given shape.
array scatter(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given indices.
array operator<<(const array &a, const array &b)
array slice(const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Slice an array.
array isposinf(const array &a, StreamOrDevice s={})
array cumsum(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative sum of an array.
array operator-(const array &a)
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array entries given indices along the axis.
array argmax(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the maximum value in the array.
array conv_transpose2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D transposed convolution with a filter
array sin(const array &a, StreamOrDevice s={})
Sine of the elements of an array.
array operator&&(const array &a, const array &b)
array cummax(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative max of an array.
array operator<(const array &a, const array &b)
Definition ops.h:359
array atleast_2d(const array &a, StreamOrDevice s={})
array operator/(const array &a, const array &b)
array allclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
True if the two arrays are equal within the specified tolerance.
array operator&(const array &a, const array &b)
array argpartition(const array &a, int kth, StreamOrDevice s={})
Returns indices that partition the flattened array such that the smaller kth elements are first.
array greater(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a > b) element-wise.
array sinh(const array &a, StreamOrDevice s={})
Hyperbolic Sine of the elements of an array.
array multiply(const array &a, const array &b, StreamOrDevice s={})
Multiply two arrays.
array tensordot(const array &a, const array &b, const int axis=2, StreamOrDevice s={})
Returns a contraction of a and b over multiple dimensions.
array real(const array &a, StreamOrDevice s={})
array stack(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Stack arrays along a new axis.
array logaddexp(const array &a, const array &b, StreamOrDevice s={})
Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
array right_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the right.
array zeros_like(const array &a, StreamOrDevice s={})
Definition allocator.h:7
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
Stream to_stream(StreamOrDevice s)
void copy(const array &src, array &dst, CopyType ctype)
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
bool operator==(const Device &lhs, const Device &rhs)
bool operator!=(const Device &lhs, const Device &rhs)
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
Definition dtype.h:13