MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/stream.h"
10#include "mlx/utils.h"
11
12namespace mlx::core {
13
23 double start,
24 double stop,
25 double step,
26 Dtype dtype,
27 StreamOrDevice s = {});
28array arange(double start, double stop, double step, StreamOrDevice s = {});
29array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
30array arange(double start, double stop, StreamOrDevice s = {});
31array arange(double stop, Dtype dtype, StreamOrDevice s = {});
32array arange(double stop, StreamOrDevice s = {});
33
34array arange(int start, int stop, int step, StreamOrDevice s = {});
35array arange(int start, int stop, StreamOrDevice s = {});
36array arange(int stop, StreamOrDevice s = {});
37
40 double start,
41 double stop,
42 int num = 50,
43 Dtype dtype = float32,
44 StreamOrDevice s = {});
45
48
51 array a,
52 Shape shape,
53 Strides strides,
54 size_t offset,
55 StreamOrDevice s = {});
56
59
61array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
62array full(Shape shape, array vals, StreamOrDevice s = {});
63template <typename T>
64array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
65 return full(std::move(shape), array(val, dtype), to_stream(s));
66}
67template <typename T>
68array full(Shape shape, T val, StreamOrDevice s = {}) {
69 return full(std::move(shape), array(val), to_stream(s));
70}
71
73array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
74inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
75 return zeros(shape, float32, s);
76}
78
80array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
81inline array ones(const Shape& shape, StreamOrDevice s = {}) {
82 return ones(shape, float32, s);
83}
85
88array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
89inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
90 return eye(n, n, 0, dtype, s);
91}
92inline array eye(int n, int m, StreamOrDevice s = {}) {
93 return eye(n, m, 0, float32, s);
94}
95inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
96 return eye(n, m, k, float32, s);
97}
98inline array eye(int n, StreamOrDevice s = {}) {
99 return eye(n, n, 0, float32, s);
100}
101
104array identity(int n, Dtype dtype, StreamOrDevice s = {});
105inline array identity(int n, StreamOrDevice s = {}) {
106 return identity(n, float32, s);
107}
108
109array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
110inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
111 return tri(n, n, 0, type, s);
112}
113
114array tril(array x, int k = 0, StreamOrDevice s = {});
115array triu(array x, int k = 0, StreamOrDevice s = {});
116
118array reshape(const array& a, Shape shape, StreamOrDevice s = {});
119
122 const array& a,
123 int start_axis,
124 int end_axis = -1,
125 StreamOrDevice s = {});
126
129
132 const array& a,
133 std::optional<float> scale = std::nullopt,
134 StreamOrDevice s = {});
135
138 const array& a,
139 const std::vector<int>& axes,
140 StreamOrDevice s = {});
141
143array squeeze(const array& a, int axis, StreamOrDevice s = {});
144
147
150 const array& a,
151 const std::vector<int>& axes,
152 StreamOrDevice s = {});
153
155array expand_dims(const array& a, int axis, StreamOrDevice s = {});
156
159 const array& a,
160 Shape start,
161 Shape stop,
162 Shape strides,
163 StreamOrDevice s = {});
164
166array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
167
170 const array& src,
171 const array& update,
172 Shape start,
173 Shape stop,
174 Shape strides,
175 StreamOrDevice s = {});
176
179 const array& src,
180 const array& update,
181 Shape start,
182 Shape stop,
183 StreamOrDevice s = {});
184
186std::vector<array>
187split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
188std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
189std::vector<array> split(
190 const array& a,
191 const std::vector<int>& indices,
192 int axis,
193 StreamOrDevice s = {});
194std::vector<array>
195split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
196
198std::vector<array> meshgrid(
199 const std::vector<array>& arrays,
200 bool sparse = false,
201 std::string indexing = "xy",
202 StreamOrDevice s = {});
203
208 const array& a,
209 const std::optional<array>& a_min = std::nullopt,
210 const std::optional<array>& a_max = std::nullopt,
211 StreamOrDevice s = {});
212
215 const std::vector<array>& arrays,
216 int axis,
217 StreamOrDevice s = {});
218array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
219
221array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
222array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
223
225array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
226array repeat(const array& arr, int repeats, StreamOrDevice s = {});
227
228array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
229
231array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
233 const array& a,
234 std::initializer_list<int> axes,
235 StreamOrDevice s = {}) {
236 return transpose(a, std::vector<int>(axes), s);
237}
238
240array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
241
244 const array& a,
245 int source,
246 int destination,
247 StreamOrDevice s = {});
248
251 const array& a,
252 const std::vector<int>& axes,
253 const std::vector<int>& low_pad_size,
254 const std::vector<int>& high_pad_size,
255 const array& pad_value = array(0),
256 const std::string mode = "constant",
257 StreamOrDevice s = {});
258
261 const array& a,
262 const std::vector<std::pair<int, int>>& pad_width,
263 const array& pad_value = array(0),
264 const std::string mode = "constant",
265 StreamOrDevice s = {});
267 const array& a,
268 const std::pair<int, int>& pad_width,
269 const array& pad_value = array(0),
270 const std::string mode = "constant",
271 StreamOrDevice s = {});
273 const array& a,
274 int pad_width,
275 const array& pad_value = array(0),
276 const std::string mode = "constant",
277 StreamOrDevice s = {});
278
281
283array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
284
286std::vector<array> broadcast_arrays(
287 const std::vector<array>& inputs,
288 StreamOrDevice s = {});
289
291array equal(const array& a, const array& b, StreamOrDevice s = {});
292inline array operator==(const array& a, const array& b) {
293 return equal(a, b);
294}
295template <typename T>
296array operator==(T a, const array& b) {
297 return equal(array(a), b);
298}
299template <typename T>
300array operator==(const array& a, T b) {
301 return equal(a, array(b));
302}
303
305array not_equal(const array& a, const array& b, StreamOrDevice s = {});
306inline array operator!=(const array& a, const array& b) {
307 return not_equal(a, b);
308}
309template <typename T>
310array operator!=(T a, const array& b) {
311 return not_equal(array(a), b);
312}
313template <typename T>
314array operator!=(const array& a, T b) {
315 return not_equal(a, array(b));
316}
317
319array greater(const array& a, const array& b, StreamOrDevice s = {});
320inline array operator>(const array& a, const array& b) {
321 return greater(a, b);
322}
323template <typename T>
324array operator>(T a, const array& b) {
325 return greater(array(a), b);
326}
327template <typename T>
328array operator>(const array& a, T b) {
329 return greater(a, array(b));
330}
331
333array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
334inline array operator>=(const array& a, const array& b) {
335 return greater_equal(a, b);
336}
337template <typename T>
338array operator>=(T a, const array& b) {
339 return greater_equal(array(a), b);
340}
341template <typename T>
342array operator>=(const array& a, T b) {
343 return greater_equal(a, array(b));
344}
345
347array less(const array& a, const array& b, StreamOrDevice s = {});
348inline array operator<(const array& a, const array& b) {
349 return less(a, b);
350}
351template <typename T>
352array operator<(T a, const array& b) {
353 return less(array(a), b);
354}
355template <typename T>
356array operator<(const array& a, T b) {
357 return less(a, array(b));
358}
359
361array less_equal(const array& a, const array& b, StreamOrDevice s = {});
362inline array operator<=(const array& a, const array& b) {
363 return less_equal(a, b);
364}
365template <typename T>
366array operator<=(T a, const array& b) {
367 return less_equal(array(a), b);
368}
369template <typename T>
370array operator<=(const array& a, T b) {
371 return less_equal(a, array(b));
372}
373
376 const array& a,
377 const array& b,
378 bool equal_nan,
379 StreamOrDevice s = {});
380inline array
381array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
382 return array_equal(a, b, false, s);
383}
384
385array isnan(const array& a, StreamOrDevice s = {});
386
387array isinf(const array& a, StreamOrDevice s = {});
388
390
392
394
397 const array& condition,
398 const array& x,
399 const array& y,
400 StreamOrDevice s = {});
401
404 const array& a,
405 float nan = 0.0f,
406 const std::optional<float> posinf = std::nullopt,
407 const std::optional<float> neginf = std::nullopt,
408 StreamOrDevice s = {});
409
411array all(const array& a, bool keepdims, StreamOrDevice s = {});
412inline array all(const array& a, StreamOrDevice s = {}) {
413 return all(a, false, to_stream(s));
414}
415
418 const array& a,
419 const array& b,
420 double rtol = 1e-5,
421 double atol = 1e-8,
422 bool equal_nan = false,
423 StreamOrDevice s = {});
424
428 const array& a,
429 const array& b,
430 double rtol = 1e-5,
431 double atol = 1e-8,
432 bool equal_nan = false,
433 StreamOrDevice s = {});
434
440 const array& a,
441 const std::vector<int>& axes,
442 bool keepdims = false,
443 StreamOrDevice s = {});
444
450 const array& a,
451 int axis,
452 bool keepdims = false,
453 StreamOrDevice s = {});
454
456array any(const array& a, bool keepdims, StreamOrDevice s = {});
457inline array any(const array& a, StreamOrDevice s = {}) {
458 return any(a, false, to_stream(s));
459}
460
466 const array& a,
467 const std::vector<int>& axes,
468 bool keepdims = false,
469 StreamOrDevice s = {});
470
476 const array& a,
477 int axis,
478 bool keepdims = false,
479 StreamOrDevice s = {});
480
482array sum(const array& a, bool keepdims, StreamOrDevice s = {});
483inline array sum(const array& a, StreamOrDevice s = {}) {
484 return sum(a, false, to_stream(s));
485}
486
489 const array& a,
490 const std::vector<int>& axes,
491 bool keepdims = false,
492 StreamOrDevice s = {});
493
496 const array& a,
497 int axis,
498 bool keepdims = false,
499 StreamOrDevice s = {});
500
502array mean(const array& a, bool keepdims, StreamOrDevice s = {});
503inline array mean(const array& a, StreamOrDevice s = {}) {
504 return mean(a, false, to_stream(s));
505}
506
509 const array& a,
510 const std::vector<int>& axes,
511 bool keepdims = false,
512 StreamOrDevice s = {});
513
516 const array& a,
517 int axis,
518 bool keepdims = false,
519 StreamOrDevice s = {});
520
522array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
523inline array var(const array& a, StreamOrDevice s = {}) {
524 return var(a, false, 0, to_stream(s));
525}
526
530 const array& a,
531 const std::vector<int>& axes,
532 bool keepdims = false,
533 int ddof = 0,
534 StreamOrDevice s = {});
535
539 const array& a,
540 int axis,
541 bool keepdims = false,
542 int ddof = 0,
543 StreamOrDevice s = {});
544
546array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
547inline array std(const array& a, StreamOrDevice s = {}) {
548 return std(a, false, 0, to_stream(s));
549}
550
554 const array& a,
555 const std::vector<int>& axes,
556 bool keepdims = false,
557 int ddof = 0,
558 StreamOrDevice s = {});
559
563 const array& a,
564 int axis,
565 bool keepdims = false,
566 int ddof = 0,
567 StreamOrDevice s = {});
568
570array prod(const array& a, bool keepdims, StreamOrDevice s = {});
571inline array prod(const array& a, StreamOrDevice s = {}) {
572 return prod(a, false, to_stream(s));
573}
574
577 const array& a,
578 const std::vector<int>& axes,
579 bool keepdims = false,
580 StreamOrDevice s = {});
581
584 const array& a,
585 int axis,
586 bool keepdims = false,
587 StreamOrDevice s = {});
588
590array max(const array& a, bool keepdims, StreamOrDevice s = {});
591inline array max(const array& a, StreamOrDevice s = {}) {
592 return max(a, false, to_stream(s));
593}
594
597 const array& a,
598 const std::vector<int>& axes,
599 bool keepdims = false,
600 StreamOrDevice s = {});
601
604 const array& a,
605 int axis,
606 bool keepdims = false,
607 StreamOrDevice s = {});
608
610array min(const array& a, bool keepdims, StreamOrDevice s = {});
611inline array min(const array& a, StreamOrDevice s = {}) {
612 return min(a, false, to_stream(s));
613}
614
617 const array& a,
618 const std::vector<int>& axes,
619 bool keepdims = false,
620 StreamOrDevice s = {});
621
624 const array& a,
625 int axis,
626 bool keepdims = false,
627 StreamOrDevice s = {});
628
630array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
631inline array argmin(const array& a, StreamOrDevice s = {}) {
632 return argmin(a, false, s);
633}
634
637 const array& a,
638 int axis,
639 bool keepdims = false,
640 StreamOrDevice s = {});
641
643array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
644inline array argmax(const array& a, StreamOrDevice s = {}) {
645 return argmax(a, false, s);
646}
647
650 const array& a,
651 int axis,
652 bool keepdims = false,
653 StreamOrDevice s = {});
654
656array sort(const array& a, StreamOrDevice s = {});
657
659array sort(const array& a, int axis, StreamOrDevice s = {});
660
663
665array argsort(const array& a, int axis, StreamOrDevice s = {});
666
671array partition(const array& a, int kth, StreamOrDevice s = {});
672
677array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
678
683array argpartition(const array& a, int kth, StreamOrDevice s = {});
684
689array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
690
692array topk(const array& a, int k, StreamOrDevice s = {});
693
695array topk(const array& a, int k, int axis, StreamOrDevice s = {});
696
698array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
699inline array logsumexp(const array& a, StreamOrDevice s = {}) {
700 return logsumexp(a, false, to_stream(s));
701}
702
705 const array& a,
706 const std::vector<int>& axes,
707 bool keepdims = false,
708 StreamOrDevice s = {});
709
712 const array& a,
713 int axis,
714 bool keepdims = false,
715 StreamOrDevice s = {});
716
718array abs(const array& a, StreamOrDevice s = {});
719
723
725array sign(const array& a, StreamOrDevice s = {});
726
729
731array logical_and(const array& a, const array& b, StreamOrDevice s = {});
732array operator&&(const array& a, const array& b);
733
735array logical_or(const array& a, const array& b, StreamOrDevice s = {});
736array operator||(const array& a, const array& b);
737
740
742array add(const array& a, const array& b, StreamOrDevice s = {});
743array operator+(const array& a, const array& b);
744template <typename T>
745array operator+(T a, const array& b) {
746 return add(array(a), b);
747}
748template <typename T>
749array operator+(const array& a, T b) {
750 return add(a, array(b));
751}
752
754array subtract(const array& a, const array& b, StreamOrDevice s = {});
755array operator-(const array& a, const array& b);
756template <typename T>
757array operator-(T a, const array& b) {
758 return subtract(array(a), b);
759}
760template <typename T>
761array operator-(const array& a, T b) {
762 return subtract(a, array(b));
763}
764
766array multiply(const array& a, const array& b, StreamOrDevice s = {});
767array operator*(const array& a, const array& b);
768template <typename T>
769array operator*(T a, const array& b) {
770 return multiply(array(a), b);
771}
772template <typename T>
773array operator*(const array& a, T b) {
774 return multiply(a, array(b));
775}
776
778array divide(const array& a, const array& b, StreamOrDevice s = {});
779array operator/(const array& a, const array& b);
780array operator/(double a, const array& b);
781array operator/(const array& a, double b);
782
784std::vector<array>
785divmod(const array& a, const array& b, StreamOrDevice s = {});
786
788array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
789
791array remainder(const array& a, const array& b, StreamOrDevice s = {});
792array operator%(const array& a, const array& b);
793template <typename T>
794array operator%(T a, const array& b) {
795 return remainder(array(a), b);
796}
797template <typename T>
798array operator%(const array& a, T b) {
799 return remainder(a, array(b));
800}
801
803array maximum(const array& a, const array& b, StreamOrDevice s = {});
804
806array minimum(const array& a, const array& b, StreamOrDevice s = {});
807
809array floor(const array& a, StreamOrDevice s = {});
810
812array ceil(const array& a, StreamOrDevice s = {});
813
816
818array exp(const array& a, StreamOrDevice s = {});
819
821array sin(const array& a, StreamOrDevice s = {});
822
824array cos(const array& a, StreamOrDevice s = {});
825
827array tan(const array& a, StreamOrDevice s = {});
828
831
834
837
839array arctan2(const array& a, const array& b, StreamOrDevice s = {});
840
842array sinh(const array& a, StreamOrDevice s = {});
843
845array cosh(const array& a, StreamOrDevice s = {});
846
848array tanh(const array& a, StreamOrDevice s = {});
849
852
855
858
861
864
866array log(const array& a, StreamOrDevice s = {});
867
869array log2(const array& a, StreamOrDevice s = {});
870
872array log10(const array& a, StreamOrDevice s = {});
873
875array log1p(const array& a, StreamOrDevice s = {});
876
878array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
879
882
884array erf(const array& a, StreamOrDevice s = {});
885
888
890array expm1(const array& a, StreamOrDevice s = {});
891
894
896array round(const array& a, int decimals, StreamOrDevice s = {});
897inline array round(const array& a, StreamOrDevice s = {}) {
898 return round(a, 0, s);
899}
900
902array matmul(const array& a, const array& b, StreamOrDevice s = {});
903
906 const array& a,
907 const std::vector<array>& indices,
908 const std::vector<int>& axes,
909 const Shape& slice_sizes,
910 StreamOrDevice s = {});
912 const array& a,
913 const array& indices,
914 int axis,
915 const Shape& slice_sizes,
916 StreamOrDevice s = {}) {
917 return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
918}
919
922 const array& a,
923 const array& indices,
924 int axis,
925 StreamOrDevice s = {});
926array take(const array& a, int index, int axis, StreamOrDevice s = {});
927
929array take(const array& a, const array& indices, StreamOrDevice s = {});
930array take(const array& a, int index, StreamOrDevice s = {});
931
934 const array& a,
935 const array& indices,
936 int axis,
937 StreamOrDevice s = {});
938
941 const array& a,
942 const array& indices,
943 const array& values,
944 int axis,
945 StreamOrDevice s = {});
946
1046 const array& a,
1047 const std::vector<array>& indices,
1048 const array& updates,
1049 const std::vector<int>& axes,
1050 StreamOrDevice s = {});
1052 const array& a,
1053 const array& indices,
1054 const array& updates,
1055 int axis,
1056 StreamOrDevice s = {}) {
1057 return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
1058}
1059
1062 const array& a,
1063 const std::vector<array>& indices,
1064 const array& updates,
1065 const std::vector<int>& axes,
1066 StreamOrDevice s = {});
1068 const array& a,
1069 const array& indices,
1070 const array& updates,
1071 int axis,
1072 StreamOrDevice s = {}) {
1073 return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
1074}
1075
1078 const array& a,
1079 const std::vector<array>& indices,
1080 const array& updates,
1081 const std::vector<int>& axes,
1082 StreamOrDevice s = {});
1084 const array& a,
1085 const array& indices,
1086 const array& updates,
1087 int axis,
1088 StreamOrDevice s = {}) {
1089 return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
1090}
1091
1094 const array& a,
1095 const std::vector<array>& indices,
1096 const array& updates,
1097 const std::vector<int>& axes,
1098 StreamOrDevice s = {});
1100 const array& a,
1101 const array& indices,
1102 const array& updates,
1103 int axis,
1104 StreamOrDevice s = {}) {
1105 return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
1106}
1109 const array& a,
1110 const std::vector<array>& indices,
1111 const array& updates,
1112 const std::vector<int>& axes,
1113 StreamOrDevice s = {});
1115 const array& a,
1116 const array& indices,
1117 const array& updates,
1118 int axis,
1119 StreamOrDevice s = {}) {
1120 return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
1121}
1122
1124array sqrt(const array& a, StreamOrDevice s = {});
1125
1128
1131 const array& a,
1132 const std::vector<int>& axes,
1133 bool precise = false,
1134 StreamOrDevice s = {});
1135
1137array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
1138
1140inline array
1141softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
1142 return softmax(a, std::vector<int>{axis}, precise, s);
1143}
1144
1146array power(const array& a, const array& b, StreamOrDevice s = {});
1147
1150 const array& a,
1151 int axis,
1152 bool reverse = false,
1153 bool inclusive = true,
1154 StreamOrDevice s = {});
1155
1158 const array& a,
1159 int axis,
1160 bool reverse = false,
1161 bool inclusive = true,
1162 StreamOrDevice s = {});
1163
1166 const array& a,
1167 int axis,
1168 bool reverse = false,
1169 bool inclusive = true,
1170 StreamOrDevice s = {});
1171
1174 const array& a,
1175 int axis,
1176 bool reverse = false,
1177 bool inclusive = true,
1178 StreamOrDevice s = {});
1179
1182 array input,
1183 array weight,
1184 std::vector<int> stride = {},
1185 std::vector<int> padding_lo = {},
1186 std::vector<int> padding_hi = {},
1187 std::vector<int> kernel_dilation = {},
1188 std::vector<int> input_dilation = {},
1189 int groups = 1,
1190 bool flip = false,
1191 StreamOrDevice s = {});
1192
1195 const array& input,
1196 const array& weight,
1197 std::vector<int> stride = {},
1198 std::vector<int> padding = {},
1199 std::vector<int> kernel_dilation = {},
1200 std::vector<int> input_dilation = {},
1201 int groups = 1,
1202 bool flip = false,
1203 StreamOrDevice s = {}) {
1204 return conv_general(
1205 /* const array& input = */ input,
1206 /* const array& weight = */ weight,
1207 /* std::vector<int> stride = */ stride,
1208 /* std::vector<int> padding_lo = */ padding,
1209 /* std::vector<int> padding_hi = */ padding,
1210 /* std::vector<int> kernel_dilation = */ kernel_dilation,
1211 /* std::vector<int> input_dilation = */ input_dilation,
1212 /* int groups = */ groups,
1213 /* bool flip = */ flip,
1214 /* StreamOrDevice s = */ s);
1215}
1216
1219 const array& input,
1220 const array& weight,
1221 int stride = 1,
1222 int padding = 0,
1223 int dilation = 1,
1224 int groups = 1,
1225 StreamOrDevice s = {});
1226
1229 const array& input,
1230 const array& weight,
1231 const std::pair<int, int>& stride = {1, 1},
1232 const std::pair<int, int>& padding = {0, 0},
1233 const std::pair<int, int>& dilation = {1, 1},
1234 int groups = 1,
1235 StreamOrDevice s = {});
1236
1239 const array& input,
1240 const array& weight,
1241 const std::tuple<int, int, int>& stride = {1, 1, 1},
1242 const std::tuple<int, int, int>& padding = {0, 0, 0},
1243 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1244 int groups = 1,
1245 StreamOrDevice s = {});
1246
1249 const array& input,
1250 const array& weight,
1251 int stride = 1,
1252 int padding = 0,
1253 int dilation = 1,
1254 int groups = 1,
1255 StreamOrDevice s = {});
1256
1259 const array& input,
1260 const array& weight,
1261 const std::pair<int, int>& stride = {1, 1},
1262 const std::pair<int, int>& padding = {0, 0},
1263 const std::pair<int, int>& dilation = {1, 1},
1264 int groups = 1,
1265 StreamOrDevice s = {});
1266
1269 const array& input,
1270 const array& weight,
1271 const std::tuple<int, int, int>& stride = {1, 1, 1},
1272 const std::tuple<int, int, int>& padding = {0, 0, 0},
1273 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1274 int groups = 1,
1275 StreamOrDevice s = {});
1276
1279 array x,
1280 array w,
1281 array scales,
1282 array biases,
1283 bool transpose = true,
1284 int group_size = 64,
1285 int bits = 4,
1286 StreamOrDevice s = {});
1287
1289std::tuple<array, array, array> quantize(
1290 const array& w,
1291 int group_size = 64,
1292 int bits = 4,
1293 StreamOrDevice s = {});
1294
1297 const array& w,
1298 const array& scales,
1299 const array& biases,
1300 int group_size = 64,
1301 int bits = 4,
1302 StreamOrDevice s = {});
1303
1306 const array& x,
1307 const array& w,
1308 const array& scales,
1309 const array& biases,
1310 std::optional<array> lhs_indices = std::nullopt,
1311 std::optional<array> rhs_indices = std::nullopt,
1312 bool transpose = true,
1313 int group_size = 64,
1314 int bits = 4,
1315 StreamOrDevice s = {});
1316
1319 const array& a,
1320 const array& b,
1321 const int axis = 2,
1322 StreamOrDevice s = {});
1323
1325 const array& a,
1326 const array& b,
1327 const std::vector<int>& axes_a,
1328 const std::vector<int>& axes_b,
1329 StreamOrDevice s = {});
1330
1332array outer(const array& a, const array& b, StreamOrDevice s = {});
1333
1335array inner(const array& a, const array& b, StreamOrDevice s = {});
1336
1339 array c,
1340 array a,
1341 array b,
1342 const float& alpha = 1.f,
1343 const float& beta = 1.f,
1344 StreamOrDevice s = {});
1345
1348 array a,
1349 array b,
1350 int block_size,
1351 std::optional<array> mask_out = std::nullopt,
1352 std::optional<array> mask_lhs = std::nullopt,
1353 std::optional<array> mask_rhs = std::nullopt,
1354 StreamOrDevice s = {});
1355
1358 array a,
1359 array b,
1360 std::optional<array> lhs_indices = std::nullopt,
1361 std::optional<array> rhs_indices = std::nullopt,
1362 StreamOrDevice s = {});
1363
1366 const array& a,
1367 int offset = 0,
1368 int axis1 = 0,
1369 int axis2 = 1,
1370 StreamOrDevice s = {});
1371
1373array diag(const array& a, int k = 0, StreamOrDevice s = {});
1374
1377 const array& a,
1378 int offset,
1379 int axis1,
1380 int axis2,
1381 Dtype dtype,
1382 StreamOrDevice s = {});
1384 const array& a,
1385 int offset,
1386 int axis1,
1387 int axis2,
1388 StreamOrDevice s = {});
1390
1396std::vector<array> depends(
1397 const std::vector<array>& inputs,
1398 const std::vector<array>& dependencies);
1399
1402std::vector<array> atleast_1d(
1403 const std::vector<array>& a,
1404 StreamOrDevice s = {});
1406std::vector<array> atleast_2d(
1407 const std::vector<array>& a,
1408 StreamOrDevice s = {});
1410std::vector<array> atleast_3d(
1411 const std::vector<array>& a,
1412 StreamOrDevice s = {});
1413
1419 const array& a,
1420 std::vector<int> axes,
1421 bool inverted,
1422 Dtype dtype = int32,
1423 StreamOrDevice s = {});
1424
1426
1428array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
1429array operator&(const array& a, const array& b);
1430
1432array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
1433array operator|(const array& a, const array& b);
1434
1436array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
1437array operator^(const array& a, const array& b);
1438
1440array left_shift(const array& a, const array& b, StreamOrDevice s = {});
1441array operator<<(const array& a, const array& b);
1442
1444array right_shift(const array& a, const array& b, StreamOrDevice s = {});
1445array operator>>(const array& a, const array& b);
1446
1447array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
1448
1450array roll(const array& a, int shift, StreamOrDevice s = {});
1451array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
1452array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
1453array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
1454array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
1456 const array& a,
1457 const Shape& shift,
1458 const std::vector<int>& axes,
1459 StreamOrDevice s = {});
1460
1461/* The real part of a complex array. */
1462array real(const array& a, StreamOrDevice s = {});
1463
1464/* The imaginary part of a complex array. */
1465array imag(const array& a, StreamOrDevice s = {});
1466
1467/* Ensure the array's underlying memory is contiguous. */
1469 const array& a,
1470 bool allow_col_major = false,
1471 StreamOrDevice s = {});
1472
1475} // namespace mlx::core
Definition array.h:23
array scatter_max(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and max updates to given linear indices.
array floor_divide(const array &a, const array &b, StreamOrDevice s={})
Compute integer division.
array radians(const array &a, StreamOrDevice s={})
Convert the elements of an array from Degrees to Radians.
array reshape(const array &a, Shape shape, StreamOrDevice s={})
Reshape an array to the given shape.
array arccos(const array &a, StreamOrDevice s={})
Arc Cosine of the elements of an array.
array scatter_min(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and min updates to given linear indices.
array less_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a <= b) element-wise.
array cumprod(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative product of an array.
array astype(array a, Dtype dtype, StreamOrDevice s={})
Convert an array to the given data type.
array rsqrt(const array &a, StreamOrDevice s={})
Reciprocal of the square root of the elements of an array.
array diag(const array &a, int k=0, StreamOrDevice s={})
Extract diagonal from a 2d array or create a diagonal matrix.
array square(const array &a, StreamOrDevice s={})
Square the elements of an array.
array ceil(const array &a, StreamOrDevice s={})
Ceil the elements of an array.
array log2(const array &a, StreamOrDevice s={})
Log base 2 of the elements of an array.
array clip(const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
Clip (limit) the values in an array.
array isnan(const array &a, StreamOrDevice s={})
array isneginf(const array &a, StreamOrDevice s={})
array subtract(const array &a, const array &b, StreamOrDevice s={})
Subtract two arrays.
array cummin(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative min of an array.
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with the given value(s).
array log10(const array &a, StreamOrDevice s={})
Log base 10 of the elements of an array.
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array sign(const array &a, StreamOrDevice s={})
The sign of the elements in an array.
array cosh(const array &a, StreamOrDevice s={})
Hyperbolic Cosine of the elements of an array.
array conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
General convolution with a filter.
array logical_or(const array &a, const array &b, StreamOrDevice s={})
Logical or of two arrays.
array moveaxis(const array &a, int source, int destination, StreamOrDevice s={})
Move an axis of an array.
array operator*(const array &a, const array &b)
array operator+(const array &a, const array &b)
array operator||(const array &a, const array &b)
array not_equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a != b) element-wise.
array erf(const array &a, StreamOrDevice s={})
Computes the error function of the elements of an array.
array slice(const array &a, Shape start, Shape stop, Shape strides, StreamOrDevice s={})
Slice an array.
array sqrt(const array &a, StreamOrDevice s={})
Square root the elements of an array.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array add(const array &a, const array &b, StreamOrDevice s={})
Add two arrays.
array round(const array &a, int decimals, StreamOrDevice s={})
Round a floating point number.
array broadcast_to(const array &a, const Shape &shape, StreamOrDevice s={})
Broadcast an array to a given shape.
array conv1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D convolution with a filter
array bitwise_xor(const array &a, const array &b, StreamOrDevice s={})
Bitwise exclusive or.
array equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a == b) element-wise.
array view(const array &a, const Dtype &dtype, StreamOrDevice s={})
array gather_qmm(const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Compute matrix products with matrix-level gather.
array stop_gradient(const array &a, StreamOrDevice s={})
Stop the flow of gradients.
array scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and prod updates to given indices.
array cos(const array &a, StreamOrDevice s={})
Cosine of the elements of an array.
array operator>=(const array &a, const array &b)
Definition ops.h:334
array degrees(const array &a, StreamOrDevice s={})
Convert the elements of an array from Radians to Degrees.
array all(const array &a, bool keepdims, StreamOrDevice s={})
True if all elements in the array are true (or non-zero).
array tan(const array &a, StreamOrDevice s={})
Tangent of the elements of an array.
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
array operator>>(const array &a, const array &b)
array minimum(const array &a, const array &b, StreamOrDevice s={})
Element-wise minimum between two arrays.
array prod(const array &a, bool keepdims, StreamOrDevice s={})
The product of all elements of the array.
array atleast_3d(const array &a, StreamOrDevice s={})
array operator<=(const array &a, const array &b)
Definition ops.h:362
array reciprocal(const array &a, StreamOrDevice s={})
The reciprocal (1/x) of the elements in an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array flatten(const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
Flatten the dimensions in the range [start_axis, end_axis].
array isclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
array operator|(const array &a, const array &b)
array topk(const array &a, int k, StreamOrDevice s={})
Returns topk elements of the flattened array.
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array abs(const array &a, StreamOrDevice s={})
Absolute value of elements in an array.
std::vector< array > meshgrid(const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
A vector of coordinate arrays from coordinate vectors.
array conjugate(const array &a, StreamOrDevice s={})
array tanh(const array &a, StreamOrDevice s={})
Hyperbolic Tangent of the elements of an array.
array as_strided(array a, Shape shape, Strides strides, size_t offset, StreamOrDevice s={})
Create a view of an array with the given shape and strides.
array inner(const array &a, const array &b, StreamOrDevice s={})
Compute the inner product of two vectors.
array block_masked_mm(array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
Compute matrix product with block masking.
array arctan2(const array &a, const array &b, StreamOrDevice s={})
Inverse tangent of the ratio of two arrays.
array number_of_elements(const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
Extract the number of elements along some axes as a scalar array.
array conv3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D convolution with a filter
array log(const array &a, StreamOrDevice s={})
Natural logarithm of the elements of an array.
array sigmoid(const array &a, StreamOrDevice s={})
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
array squeeze(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Remove singleton dimensions at the given axes.
array greater_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a >= b) element-wise.
array expand_dims(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Add a singleton dimension at the given axes.
array isfinite(const array &a, StreamOrDevice s={})
array conv2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D convolution with a filter
array operator>(const array &a, const array &b)
Definition ops.h:320
array bitwise_and(const array &a, const array &b, StreamOrDevice s={})
Bitwise and.
std::vector< array > split(const array &a, int num_splits, int axis, StreamOrDevice s={})
Split an array into sub-arrays along a given axis.
array matmul(const array &a, const array &b, StreamOrDevice s={})
Matrix-matrix multiplication.
array logical_and(const array &a, const array &b, StreamOrDevice s={})
Logical and of two arrays.
array erfinv(const array &a, StreamOrDevice s={})
Computes the inverse error function of the elements of an array.
array divide(const array &a, const array &b, StreamOrDevice s={})
Divide two arrays.
array power(const array &a, const array &b, StreamOrDevice s={})
Raise elements of a to the power of b element-wise.
array maximum(const array &a, const array &b, StreamOrDevice s={})
Element-wise maximum between two arrays.
array slice_update(const array &src, const array &update, Shape start, Shape stop, Shape strides, StreamOrDevice s={})
Update a slice from the source array.
array argmin(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the minimum value in the array.
array var(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the variance of the elements of an array.
array softmax(const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
Softmax of an array.
array sort(const array &a, StreamOrDevice s={})
Returns a sorted copy of the flattened array.
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array imag(const array &a, StreamOrDevice s={})
array pad(const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
Pad an array with a constant value.
array addmm(array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
Compute D = beta * C + alpha * (A @ B)
array tril(array x, int k=0, StreamOrDevice s={})
array any(const array &a, bool keepdims, StreamOrDevice s={})
True if any elements in the array are true (or non-zero).
array outer(const array &a, const array &b, StreamOrDevice s={})
Compute the outer product of two vectors.
array hadamard_transform(const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
Multiply the array by the Hadamard matrix of corresponding size.
array arcsin(const array &a, StreamOrDevice s={})
Arc Sine of the elements of an array.
array left_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the left.
array where(const array &condition, const array &x, const array &y, StreamOrDevice s={})
Select from x or y depending on condition.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
array contiguous(const array &a, bool allow_col_major=false, StreamOrDevice s={})
array bitwise_or(const array &a, const array &b, StreamOrDevice s={})
Bitwise inclusive or.
array gather_mm(array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
Compute matrix product with matrix-level gather.
array floor(const array &a, StreamOrDevice s={})
Floor the elements of an array.
array conv_transpose3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D transposed convolution with a filter
array argsort(const array &a, StreamOrDevice s={})
Returns indices that sort the flattened array.
array put_along_axis(const array &a, const array &indices, const array &values, int axis, StreamOrDevice s={})
Put the values into the array at the given indices along the axis.
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
True if two arrays have the same shape and elements.
array isinf(const array &a, StreamOrDevice s={})
array gather(const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const Shape &slice_sizes, StreamOrDevice s={})
Gather array entries given indices and slices.
array less(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a < b) element-wise.
array diagonal(const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
Extract a diagonal or construct a diagonal array.
array ones_like(const array &a, StreamOrDevice s={})
array negative(const array &a, StreamOrDevice s={})
Negate an array.
array linspace(double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
A 1D array of num evenly spaced numbers in the range [start, stop]
array remainder(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise remainder of division.
array arctan(const array &a, StreamOrDevice s={})
Arc Tangent of the elements of an array.
array conv_transpose1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D transposed convolution with a filter
std::vector< array > divmod(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise quotient and remainder.
array triu(array x, int k=0, StreamOrDevice s={})
array arccosh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Cosine of the elements of an array.
array tile(const array &arr, std::vector< int > reps, StreamOrDevice s={})
array nan_to_num(const array &a, float nan=0.0f, const std::optional< float > posinf=std::nullopt, const std::optional< float > neginf=std::nullopt, StreamOrDevice s={})
Replace NaN and infinities with finite numbers.
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator%(const array &a, const array &b)
std::tuple< array, array, array > quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
Quantize a matrix along its last axis.
array arctanh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Tangent of the elements of an array.
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s={})
Repeat an array along an axis.
std::vector< array > broadcast_arrays(const std::vector< array > &inputs, StreamOrDevice s={})
Broadcast a vector of arrays against one another.
array atleast_1d(const array &a, StreamOrDevice s={})
Convert an array to an array with at least ndim dimensions.
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s={})
Swap two axes of an array.
array logical_not(const array &a, StreamOrDevice s={})
Logical not of an array.
array trace(const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
Return the sum along a specified diagonal in the given array.
array quantized_matmul(array x, array w, array scales, array biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Quantized matmul multiplies x with a quantized matrix w.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array partition(const array &a, int kth, StreamOrDevice s={})
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
array take(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array slices at the given indices of the specified axis.
array operator^(const array &a, const array &b)
array roll(const array &a, int shift, StreamOrDevice s={})
Roll elements along an axis and introduce them on the other side.
std::vector< array > depends(const std::vector< array > &inputs, const std::vector< array > &dependencies)
Implements the identity function but allows injecting dependencies to other arrays.
array arcsinh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Sine of the elements of an array.
array scatter_add(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and add updates to given indices.
array logsumexp(const array &a, bool keepdims, StreamOrDevice s={})
The logsumexp of all elements of the array.
array scatter(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given indices.
array operator<<(const array &a, const array &b)
array isposinf(const array &a, StreamOrDevice s={})
array cumsum(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative sum of an array.
array operator-(const array &a)
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
array ones(const Shape &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with ones.
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array entries given indices along the axis.
array zeros(const Shape &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with zeros.
array argmax(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the maximum value in the array.
array conv_transpose2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D transposed convolution with a filter
array sin(const array &a, StreamOrDevice s={})
Sine of the elements of an array.
array operator&&(const array &a, const array &b)
array cummax(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative max of an array.
array operator<(const array &a, const array &b)
Definition ops.h:348
array atleast_2d(const array &a, StreamOrDevice s={})
array operator/(const array &a, const array &b)
array allclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
True if the two arrays are equal within the specified tolerance.
array operator&(const array &a, const array &b)
array argpartition(const array &a, int kth, StreamOrDevice s={})
Returns indices that partition the flattened array such that the smaller kth elements are first.
array greater(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a > b) element-wise.
array sinh(const array &a, StreamOrDevice s={})
Hyperbolic Sine of the elements of an array.
array multiply(const array &a, const array &b, StreamOrDevice s={})
Multiply two arrays.
array tensordot(const array &a, const array &b, const int axis=2, StreamOrDevice s={})
Returns a contraction of a and b over multiple dimensions.
array real(const array &a, StreamOrDevice s={})
array stack(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Stack arrays along a new axis.
array logaddexp(const array &a, const array &b, StreamOrDevice s={})
Log-add-exp of the elements of two arrays: log(exp(a) + exp(b)).
array right_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the right.
array zeros_like(const array &a, StreamOrDevice s={})
Definition allocator.h:7
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
Stream to_stream(StreamOrDevice s)
void copy(const array &src, array &dst, CopyType ctype)
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
void concatenate(std::string &acc, T first)
Definition utils.h:66
bool operator==(const Device &lhs, const Device &rhs)
bool operator!=(const Device &lhs, const Device &rhs)
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
std::vector< int32_t > Shape
Definition array.h:20
std::vector< size_t > Strides
Definition array.h:21
Definition dtype.h:13