MLX
 
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/stream.h"
10#include "mlx/utils.h"
11
12namespace mlx::core {
13
18
23 double start,
24 double stop,
25 double step,
26 Dtype dtype,
27 StreamOrDevice s = {});
28array arange(double start, double stop, double step, StreamOrDevice s = {});
29array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
30array arange(double start, double stop, StreamOrDevice s = {});
31array arange(double stop, Dtype dtype, StreamOrDevice s = {});
32array arange(double stop, StreamOrDevice s = {});
33
34array arange(int start, int stop, int step, StreamOrDevice s = {});
35array arange(int start, int stop, StreamOrDevice s = {});
36array arange(int stop, StreamOrDevice s = {});
37
40 double start,
41 double stop,
42 int num = 50,
43 Dtype dtype = float32,
44 StreamOrDevice s = {});
45
48
51 array a,
52 Shape shape,
53 Strides strides,
54 size_t offset,
55 StreamOrDevice s = {});
56
59
61array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
62array full(Shape shape, array vals, StreamOrDevice s = {});
63template <typename T>
64array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
65 return full(std::move(shape), array(val, dtype), to_stream(s));
66}
67template <typename T>
68array full(Shape shape, T val, StreamOrDevice s = {}) {
69 return full(std::move(shape), array(val), to_stream(s));
70}
71
73array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
74inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
75 return zeros(shape, float32, s);
76}
78
80array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
81inline array ones(const Shape& shape, StreamOrDevice s = {}) {
82 return ones(shape, float32, s);
83}
85
88array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
89inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
90 return eye(n, n, 0, dtype, s);
91}
92inline array eye(int n, int m, StreamOrDevice s = {}) {
93 return eye(n, m, 0, float32, s);
94}
95inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
96 return eye(n, m, k, float32, s);
97}
98inline array eye(int n, StreamOrDevice s = {}) {
99 return eye(n, n, 0, float32, s);
100}
101
104array identity(int n, Dtype dtype, StreamOrDevice s = {});
105inline array identity(int n, StreamOrDevice s = {}) {
106 return identity(n, float32, s);
107}
108
109array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
110inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
111 return tri(n, n, 0, type, s);
112}
113
114array tril(array x, int k = 0, StreamOrDevice s = {});
115array triu(array x, int k = 0, StreamOrDevice s = {});
116
118array reshape(const array& a, Shape shape, StreamOrDevice s = {});
119
121array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
122
125 const array& a,
126 int start_axis,
127 int end_axis = -1,
128 StreamOrDevice s = {});
129
132
135 const array& a,
136 std::optional<float> scale = std::nullopt,
137 StreamOrDevice s = {});
138
141 const array& a,
142 const std::vector<int>& axes,
143 StreamOrDevice s = {});
144
146array squeeze(const array& a, int axis, StreamOrDevice s = {});
147
150
153 const array& a,
154 const std::vector<int>& axes,
155 StreamOrDevice s = {});
156
158array expand_dims(const array& a, int axis, StreamOrDevice s = {});
159
162 const array& a,
163 Shape start,
164 Shape stop,
165 Shape strides,
166 StreamOrDevice s = {});
168 const array& a,
169 std::initializer_list<int> start,
170 Shape stop,
171 Shape strides,
172 StreamOrDevice s = {}) {
173 return slice(a, Shape(start), std::move(stop), std::move(strides), s);
174}
175
177array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
178
181 const array& a,
182 const array& start,
183 std::vector<int> axes,
184 Shape slice_size,
185 StreamOrDevice s = {});
186
189 const array& src,
190 const array& update,
191 Shape start,
192 Shape stop,
193 Shape strides,
194 StreamOrDevice s = {});
195
198 const array& src,
199 const array& update,
200 Shape start,
201 Shape stop,
202 StreamOrDevice s = {});
203
206 const array& src,
207 const array& update,
208 const array& start,
209 std::vector<int> axes,
210 StreamOrDevice s = {});
211
213std::vector<array>
214split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
215std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
216std::vector<array>
217split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
218std::vector<array>
219split(const array& a, const Shape& indices, StreamOrDevice s = {});
220
222std::vector<array> meshgrid(
223 const std::vector<array>& arrays,
224 bool sparse = false,
225 const std::string& indexing = "xy",
226 StreamOrDevice s = {});
227
232 const array& a,
233 const std::optional<array>& a_min = std::nullopt,
234 const std::optional<array>& a_max = std::nullopt,
235 StreamOrDevice s = {});
236
238array concatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});
239array concatenate(std::vector<array> arrays, StreamOrDevice s = {});
240
242array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
243array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
244
246array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
247array repeat(const array& arr, int repeats, StreamOrDevice s = {});
248
249array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
250
252array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
254 const array& a,
255 std::initializer_list<int> axes,
256 StreamOrDevice s = {}) {
257 return transpose(a, std::vector<int>(axes), s);
258}
259
261array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
262
265 const array& a,
266 int source,
267 int destination,
268 StreamOrDevice s = {});
269
272 const array& a,
273 const std::vector<int>& axes,
274 const Shape& low_pad_size,
275 const Shape& high_pad_size,
276 const array& pad_value = array(0),
277 const std::string& mode = "constant",
278 StreamOrDevice s = {});
279
282 const array& a,
283 const std::vector<std::pair<int, int>>& pad_width,
284 const array& pad_value = array(0),
285 const std::string& mode = "constant",
286 StreamOrDevice s = {});
288 const array& a,
289 const std::pair<int, int>& pad_width,
290 const array& pad_value = array(0),
291 const std::string& mode = "constant",
292 StreamOrDevice s = {});
294 const array& a,
295 int pad_width,
296 const array& pad_value = array(0),
297 const std::string& mode = "constant",
298 StreamOrDevice s = {});
299
302
304array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
305
307std::vector<array> broadcast_arrays(
308 const std::vector<array>& inputs,
309 StreamOrDevice s = {});
310
312array equal(const array& a, const array& b, StreamOrDevice s = {});
313inline array operator==(const array& a, const array& b) {
314 return equal(a, b);
315}
316template <typename T>
317array operator==(T a, const array& b) {
318 return equal(array(a), b);
319}
320template <typename T>
321array operator==(const array& a, T b) {
322 return equal(a, array(b));
323}
324
326array not_equal(const array& a, const array& b, StreamOrDevice s = {});
327inline array operator!=(const array& a, const array& b) {
328 return not_equal(a, b);
329}
330template <typename T>
331array operator!=(T a, const array& b) {
332 return not_equal(array(a), b);
333}
334template <typename T>
335array operator!=(const array& a, T b) {
336 return not_equal(a, array(b));
337}
338
340array greater(const array& a, const array& b, StreamOrDevice s = {});
341inline array operator>(const array& a, const array& b) {
342 return greater(a, b);
343}
344template <typename T>
345array operator>(T a, const array& b) {
346 return greater(array(a), b);
347}
348template <typename T>
349array operator>(const array& a, T b) {
350 return greater(a, array(b));
351}
352
354array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
355inline array operator>=(const array& a, const array& b) {
356 return greater_equal(a, b);
357}
358template <typename T>
359array operator>=(T a, const array& b) {
360 return greater_equal(array(a), b);
361}
362template <typename T>
363array operator>=(const array& a, T b) {
364 return greater_equal(a, array(b));
365}
366
368array less(const array& a, const array& b, StreamOrDevice s = {});
369inline array operator<(const array& a, const array& b) {
370 return less(a, b);
371}
372template <typename T>
373array operator<(T a, const array& b) {
374 return less(array(a), b);
375}
376template <typename T>
377array operator<(const array& a, T b) {
378 return less(a, array(b));
379}
380
382array less_equal(const array& a, const array& b, StreamOrDevice s = {});
383inline array operator<=(const array& a, const array& b) {
384 return less_equal(a, b);
385}
386template <typename T>
387array operator<=(T a, const array& b) {
388 return less_equal(array(a), b);
389}
390template <typename T>
391array operator<=(const array& a, T b) {
392 return less_equal(a, array(b));
393}
394
397 const array& a,
398 const array& b,
399 bool equal_nan,
400 StreamOrDevice s = {});
401inline array
402array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
403 return array_equal(a, b, false, s);
404}
405
406array isnan(const array& a, StreamOrDevice s = {});
407
408array isinf(const array& a, StreamOrDevice s = {});
409
411
413
415
418 const array& condition,
419 const array& x,
420 const array& y,
421 StreamOrDevice s = {});
422
425 const array& a,
426 float nan = 0.0f,
427 const std::optional<float> posinf = std::nullopt,
428 const std::optional<float> neginf = std::nullopt,
429 StreamOrDevice s = {});
430
432array all(const array& a, bool keepdims, StreamOrDevice s = {});
433inline array all(const array& a, StreamOrDevice s = {}) {
434 return all(a, false, to_stream(s));
435}
436
439 const array& a,
440 const array& b,
441 double rtol = 1e-5,
442 double atol = 1e-8,
443 bool equal_nan = false,
444 StreamOrDevice s = {});
445
449 const array& a,
450 const array& b,
451 double rtol = 1e-5,
452 double atol = 1e-8,
453 bool equal_nan = false,
454 StreamOrDevice s = {});
455
461 const array& a,
462 const std::vector<int>& axes,
463 bool keepdims = false,
464 StreamOrDevice s = {});
465
471 const array& a,
472 int axis,
473 bool keepdims = false,
474 StreamOrDevice s = {});
475
477array any(const array& a, bool keepdims, StreamOrDevice s = {});
478inline array any(const array& a, StreamOrDevice s = {}) {
479 return any(a, false, to_stream(s));
480}
481
487 const array& a,
488 const std::vector<int>& axes,
489 bool keepdims = false,
490 StreamOrDevice s = {});
491
497 const array& a,
498 int axis,
499 bool keepdims = false,
500 StreamOrDevice s = {});
501
503array sum(const array& a, bool keepdims, StreamOrDevice s = {});
504inline array sum(const array& a, StreamOrDevice s = {}) {
505 return sum(a, false, to_stream(s));
506}
507
510 const array& a,
511 const std::vector<int>& axes,
512 bool keepdims = false,
513 StreamOrDevice s = {});
514
517 const array& a,
518 int axis,
519 bool keepdims = false,
520 StreamOrDevice s = {});
521
523array mean(const array& a, bool keepdims, StreamOrDevice s = {});
524inline array mean(const array& a, StreamOrDevice s = {}) {
525 return mean(a, false, to_stream(s));
526}
527
530 const array& a,
531 const std::vector<int>& axes,
532 bool keepdims = false,
533 StreamOrDevice s = {});
534
537 const array& a,
538 int axis,
539 bool keepdims = false,
540 StreamOrDevice s = {});
541
543array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
544inline array var(const array& a, StreamOrDevice s = {}) {
545 return var(a, false, 0, to_stream(s));
546}
547
551 const array& a,
552 const std::vector<int>& axes,
553 bool keepdims = false,
554 int ddof = 0,
555 StreamOrDevice s = {});
556
560 const array& a,
561 int axis,
562 bool keepdims = false,
563 int ddof = 0,
564 StreamOrDevice s = {});
565
567array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
568inline array std(const array& a, StreamOrDevice s = {}) {
569 return std(a, false, 0, to_stream(s));
570}
571
575 const array& a,
576 const std::vector<int>& axes,
577 bool keepdims = false,
578 int ddof = 0,
579 StreamOrDevice s = {});
580
584 const array& a,
585 int axis,
586 bool keepdims = false,
587 int ddof = 0,
588 StreamOrDevice s = {});
589
591array prod(const array& a, bool keepdims, StreamOrDevice s = {});
592inline array prod(const array& a, StreamOrDevice s = {}) {
593 return prod(a, false, to_stream(s));
594}
595
598 const array& a,
599 const std::vector<int>& axes,
600 bool keepdims = false,
601 StreamOrDevice s = {});
602
605 const array& a,
606 int axis,
607 bool keepdims = false,
608 StreamOrDevice s = {});
609
611array max(const array& a, bool keepdims, StreamOrDevice s = {});
612inline array max(const array& a, StreamOrDevice s = {}) {
613 return max(a, false, to_stream(s));
614}
615
618 const array& a,
619 const std::vector<int>& axes,
620 bool keepdims = false,
621 StreamOrDevice s = {});
622
625 const array& a,
626 int axis,
627 bool keepdims = false,
628 StreamOrDevice s = {});
629
631array min(const array& a, bool keepdims, StreamOrDevice s = {});
632inline array min(const array& a, StreamOrDevice s = {}) {
633 return min(a, false, to_stream(s));
634}
635
638 const array& a,
639 const std::vector<int>& axes,
640 bool keepdims = false,
641 StreamOrDevice s = {});
642
645 const array& a,
646 int axis,
647 bool keepdims = false,
648 StreamOrDevice s = {});
649
651array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
652inline array argmin(const array& a, StreamOrDevice s = {}) {
653 return argmin(a, false, s);
654}
655
658 const array& a,
659 int axis,
660 bool keepdims = false,
661 StreamOrDevice s = {});
662
664array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
665inline array argmax(const array& a, StreamOrDevice s = {}) {
666 return argmax(a, false, s);
667}
668
671 const array& a,
672 int axis,
673 bool keepdims = false,
674 StreamOrDevice s = {});
675
677array sort(const array& a, StreamOrDevice s = {});
678
680array sort(const array& a, int axis, StreamOrDevice s = {});
681
684
686array argsort(const array& a, int axis, StreamOrDevice s = {});
687
692array partition(const array& a, int kth, StreamOrDevice s = {});
693
698array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
699
704array argpartition(const array& a, int kth, StreamOrDevice s = {});
705
710array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
711
713array topk(const array& a, int k, StreamOrDevice s = {});
714
716array topk(const array& a, int k, int axis, StreamOrDevice s = {});
717
719array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
720inline array logsumexp(const array& a, StreamOrDevice s = {}) {
721 return logsumexp(a, false, to_stream(s));
722}
723
726 const array& a,
727 const std::vector<int>& axes,
728 bool keepdims = false,
729 StreamOrDevice s = {});
730
733 const array& a,
734 int axis,
735 bool keepdims = false,
736 StreamOrDevice s = {});
737
739array abs(const array& a, StreamOrDevice s = {});
740
744
746array sign(const array& a, StreamOrDevice s = {});
747
750
752array logical_and(const array& a, const array& b, StreamOrDevice s = {});
753array operator&&(const array& a, const array& b);
754
756array logical_or(const array& a, const array& b, StreamOrDevice s = {});
757array operator||(const array& a, const array& b);
758
761
763array add(const array& a, const array& b, StreamOrDevice s = {});
764array operator+(const array& a, const array& b);
765template <typename T>
766array operator+(T a, const array& b) {
767 return add(array(a), b);
768}
769template <typename T>
770array operator+(const array& a, T b) {
771 return add(a, array(b));
772}
773
775array subtract(const array& a, const array& b, StreamOrDevice s = {});
776array operator-(const array& a, const array& b);
777template <typename T>
778array operator-(T a, const array& b) {
779 return subtract(array(a), b);
780}
781template <typename T>
782array operator-(const array& a, T b) {
783 return subtract(a, array(b));
784}
785
787array multiply(const array& a, const array& b, StreamOrDevice s = {});
788array operator*(const array& a, const array& b);
789template <typename T>
790array operator*(T a, const array& b) {
791 return multiply(array(a), b);
792}
793template <typename T>
794array operator*(const array& a, T b) {
795 return multiply(a, array(b));
796}
797
799array divide(const array& a, const array& b, StreamOrDevice s = {});
800array operator/(const array& a, const array& b);
801array operator/(double a, const array& b);
802array operator/(const array& a, double b);
803
805std::vector<array>
806divmod(const array& a, const array& b, StreamOrDevice s = {});
807
809array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
810
812array remainder(const array& a, const array& b, StreamOrDevice s = {});
813array operator%(const array& a, const array& b);
814template <typename T>
815array operator%(T a, const array& b) {
816 return remainder(array(a), b);
817}
818template <typename T>
819array operator%(const array& a, T b) {
820 return remainder(a, array(b));
821}
822
824array maximum(const array& a, const array& b, StreamOrDevice s = {});
825
827array minimum(const array& a, const array& b, StreamOrDevice s = {});
828
830array floor(const array& a, StreamOrDevice s = {});
831
833array ceil(const array& a, StreamOrDevice s = {});
834
837
839array exp(const array& a, StreamOrDevice s = {});
840
842array sin(const array& a, StreamOrDevice s = {});
843
845array cos(const array& a, StreamOrDevice s = {});
846
848array tan(const array& a, StreamOrDevice s = {});
849
852
855
858
860array arctan2(const array& a, const array& b, StreamOrDevice s = {});
861
863array sinh(const array& a, StreamOrDevice s = {});
864
866array cosh(const array& a, StreamOrDevice s = {});
867
869array tanh(const array& a, StreamOrDevice s = {});
870
873
876
879
882
885
887array log(const array& a, StreamOrDevice s = {});
888
890array log2(const array& a, StreamOrDevice s = {});
891
893array log10(const array& a, StreamOrDevice s = {});
894
896array log1p(const array& a, StreamOrDevice s = {});
897
899array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
900
903
905array erf(const array& a, StreamOrDevice s = {});
906
909
911array expm1(const array& a, StreamOrDevice s = {});
912
915
917array round(const array& a, int decimals, StreamOrDevice s = {});
918inline array round(const array& a, StreamOrDevice s = {}) {
919 return round(a, 0, s);
920}
921
923array matmul(const array& a, const array& b, StreamOrDevice s = {});
924
927 const array& a,
928 const std::vector<array>& indices,
929 const std::vector<int>& axes,
930 const Shape& slice_sizes,
931 StreamOrDevice s = {});
933 const array& a,
934 const array& indices,
935 int axis,
936 const Shape& slice_sizes,
937 StreamOrDevice s = {}) {
938 return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
939}
940
942array kron(const array& a, const array& b, StreamOrDevice s = {});
943
946 const array& a,
947 const array& indices,
948 int axis,
949 StreamOrDevice s = {});
950array take(const array& a, int index, int axis, StreamOrDevice s = {});
951
953array take(const array& a, const array& indices, StreamOrDevice s = {});
954array take(const array& a, int index, StreamOrDevice s = {});
955
958 const array& a,
959 const array& indices,
960 int axis,
961 StreamOrDevice s = {});
962
965 const array& a,
966 const array& indices,
967 const array& values,
968 int axis,
969 StreamOrDevice s = {});
970
973 const array& a,
974 const array& indices,
975 const array& values,
976 int axis,
977 StreamOrDevice s = {});
978
1078 const array& a,
1079 const std::vector<array>& indices,
1080 const array& updates,
1081 const std::vector<int>& axes,
1082 StreamOrDevice s = {});
1084 const array& a,
1085 const array& indices,
1086 const array& updates,
1087 int axis,
1088 StreamOrDevice s = {}) {
1089 return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
1090}
1091
1094 const array& a,
1095 const std::vector<array>& indices,
1096 const array& updates,
1097 const std::vector<int>& axes,
1098 StreamOrDevice s = {});
1100 const array& a,
1101 const array& indices,
1102 const array& updates,
1103 int axis,
1104 StreamOrDevice s = {}) {
1105 return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
1106}
1107
1110 const array& a,
1111 const std::vector<array>& indices,
1112 const array& updates,
1113 const std::vector<int>& axes,
1114 StreamOrDevice s = {});
1116 const array& a,
1117 const array& indices,
1118 const array& updates,
1119 int axis,
1120 StreamOrDevice s = {}) {
1121 return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
1122}
1123
1126 const array& a,
1127 const std::vector<array>& indices,
1128 const array& updates,
1129 const std::vector<int>& axes,
1130 StreamOrDevice s = {});
1132 const array& a,
1133 const array& indices,
1134 const array& updates,
1135 int axis,
1136 StreamOrDevice s = {}) {
1137 return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
1138}
1139
1141 const array& a,
1142 const std::vector<array>& indices,
1143 const array& updates,
1144 const std::vector<int>& axes,
1145 StreamOrDevice s = {});
1147 const array& a,
1148 const array& indices,
1149 const array& updates,
1150 int axis,
1151 StreamOrDevice s = {}) {
1152 return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
1153}
1154
1156array sqrt(const array& a, StreamOrDevice s = {});
1157
1160
1163 const array& a,
1164 const std::vector<int>& axes,
1165 bool precise = false,
1166 StreamOrDevice s = {});
1167
1169array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
1170
1172inline array
1173softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
1174 return softmax(a, std::vector<int>{axis}, precise, s);
1175}
1176
1178array power(const array& a, const array& b, StreamOrDevice s = {});
1179
1182 const array& a,
1183 int axis,
1184 bool reverse = false,
1185 bool inclusive = true,
1186 StreamOrDevice s = {});
1187
1190 const array& a,
1191 int axis,
1192 bool reverse = false,
1193 bool inclusive = true,
1194 StreamOrDevice s = {});
1195
1198 const array& a,
1199 int axis,
1200 bool reverse = false,
1201 bool inclusive = true,
1202 StreamOrDevice s = {});
1203
1206 const array& a,
1207 int axis,
1208 bool reverse = false,
1209 bool inclusive = true,
1210 StreamOrDevice s = {});
1211
1214 array input,
1215 array weight,
1216 std::vector<int> stride = {},
1217 std::vector<int> padding_lo = {},
1218 std::vector<int> padding_hi = {},
1219 std::vector<int> kernel_dilation = {},
1220 std::vector<int> input_dilation = {},
1221 int groups = 1,
1222 bool flip = false,
1223 StreamOrDevice s = {});
1224
1227 const array& input,
1228 const array& weight,
1229 std::vector<int> stride = {},
1230 std::vector<int> padding = {},
1231 std::vector<int> kernel_dilation = {},
1232 std::vector<int> input_dilation = {},
1233 int groups = 1,
1234 bool flip = false,
1235 StreamOrDevice s = {}) {
1236 return conv_general(
1237 /* const array& input = */ input,
1238 /* const array& weight = */ weight,
1239 /* std::vector<int> stride = */ stride,
1240 /* std::vector<int> padding_lo = */ padding,
1241 /* std::vector<int> padding_hi = */ padding,
1242 /* std::vector<int> kernel_dilation = */ kernel_dilation,
1243 /* std::vector<int> input_dilation = */ input_dilation,
1244 /* int groups = */ groups,
1245 /* bool flip = */ flip,
1246 /* StreamOrDevice s = */ s);
1247}
1248
1251 const array& input,
1252 const array& weight,
1253 int stride = 1,
1254 int padding = 0,
1255 int dilation = 1,
1256 int groups = 1,
1257 StreamOrDevice s = {});
1258
1261 const array& input,
1262 const array& weight,
1263 const std::pair<int, int>& stride = {1, 1},
1264 const std::pair<int, int>& padding = {0, 0},
1265 const std::pair<int, int>& dilation = {1, 1},
1266 int groups = 1,
1267 StreamOrDevice s = {});
1268
1271 const array& input,
1272 const array& weight,
1273 const std::tuple<int, int, int>& stride = {1, 1, 1},
1274 const std::tuple<int, int, int>& padding = {0, 0, 0},
1275 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1276 int groups = 1,
1277 StreamOrDevice s = {});
1278
1281 const array& input,
1282 const array& weight,
1283 int stride = 1,
1284 int padding = 0,
1285 int dilation = 1,
1286 int groups = 1,
1287 StreamOrDevice s = {});
1288
1291 const array& input,
1292 const array& weight,
1293 const std::pair<int, int>& stride = {1, 1},
1294 const std::pair<int, int>& padding = {0, 0},
1295 const std::pair<int, int>& dilation = {1, 1},
1296 int groups = 1,
1297 StreamOrDevice s = {});
1298
1301 const array& input,
1302 const array& weight,
1303 const std::tuple<int, int, int>& stride = {1, 1, 1},
1304 const std::tuple<int, int, int>& padding = {0, 0, 0},
1305 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1306 int groups = 1,
1307 StreamOrDevice s = {});
1308
1311 array x,
1312 array w,
1313 array scales,
1314 array biases,
1315 bool transpose = true,
1316 int group_size = 64,
1317 int bits = 4,
1318 StreamOrDevice s = {});
1319
1321std::tuple<array, array, array> quantize(
1322 const array& w,
1323 int group_size = 64,
1324 int bits = 4,
1325 StreamOrDevice s = {});
1326
1329 const array& w,
1330 const array& scales,
1331 const array& biases,
1332 int group_size = 64,
1333 int bits = 4,
1334 StreamOrDevice s = {});
1335
1338 const array& x,
1339 const array& w,
1340 const array& scales,
1341 const array& biases,
1342 std::optional<array> lhs_indices = std::nullopt,
1343 std::optional<array> rhs_indices = std::nullopt,
1344 bool transpose = true,
1345 int group_size = 64,
1346 int bits = 4,
1347 StreamOrDevice s = {});
1348
1351 const array& a,
1352 const array& b,
1353 const int axis = 2,
1354 StreamOrDevice s = {});
1355
1357 const array& a,
1358 const array& b,
1359 const std::vector<int>& axes_a,
1360 const std::vector<int>& axes_b,
1361 StreamOrDevice s = {});
1362
1364array outer(const array& a, const array& b, StreamOrDevice s = {});
1365
1367array inner(const array& a, const array& b, StreamOrDevice s = {});
1368
1371 array c,
1372 array a,
1373 array b,
1374 const float& alpha = 1.f,
1375 const float& beta = 1.f,
1376 StreamOrDevice s = {});
1377
1380 array a,
1381 array b,
1382 int block_size,
1383 std::optional<array> mask_out = std::nullopt,
1384 std::optional<array> mask_lhs = std::nullopt,
1385 std::optional<array> mask_rhs = std::nullopt,
1386 StreamOrDevice s = {});
1387
1390 array a,
1391 array b,
1392 std::optional<array> lhs_indices = std::nullopt,
1393 std::optional<array> rhs_indices = std::nullopt,
1394 StreamOrDevice s = {});
1395
1398 const array& a,
1399 int offset = 0,
1400 int axis1 = 0,
1401 int axis2 = 1,
1402 StreamOrDevice s = {});
1403
1405array diag(const array& a, int k = 0, StreamOrDevice s = {});
1406
1409 const array& a,
1410 int offset,
1411 int axis1,
1412 int axis2,
1413 Dtype dtype,
1414 StreamOrDevice s = {});
1416 const array& a,
1417 int offset,
1418 int axis1,
1419 int axis2,
1420 StreamOrDevice s = {});
1422
1428std::vector<array> depends(
1429 const std::vector<array>& inputs,
1430 const std::vector<array>& dependencies);
1431
1434std::vector<array> atleast_1d(
1435 const std::vector<array>& a,
1436 StreamOrDevice s = {});
1438std::vector<array> atleast_2d(
1439 const std::vector<array>& a,
1440 StreamOrDevice s = {});
1442std::vector<array> atleast_3d(
1443 const std::vector<array>& a,
1444 StreamOrDevice s = {});
1445
1451 const array& a,
1452 std::vector<int> axes,
1453 bool inverted,
1454 Dtype dtype = int32,
1455 StreamOrDevice s = {});
1456
1458
1460array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
1461array operator&(const array& a, const array& b);
1462
1464array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
1465array operator|(const array& a, const array& b);
1466
1468array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
1469array operator^(const array& a, const array& b);
1470
1472array left_shift(const array& a, const array& b, StreamOrDevice s = {});
1473array operator<<(const array& a, const array& b);
1474
1476array right_shift(const array& a, const array& b, StreamOrDevice s = {});
1477array operator>>(const array& a, const array& b);
1478
1482
1483array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
1484
1486array roll(const array& a, int shift, StreamOrDevice s = {});
1487array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
1488array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
1490 const array& a,
1491 int shift,
1492 const std::vector<int>& axes,
1493 StreamOrDevice s = {});
1494array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
1496 const array& a,
1497 const Shape& shift,
1498 const std::vector<int>& axes,
1499 StreamOrDevice s = {});
1500
1501/* The real part of a complex array. */
1502array real(const array& a, StreamOrDevice s = {});
1503
1504/* The imaginary part of a complex array. */
1505array imag(const array& a, StreamOrDevice s = {});
1506
1507/* Ensure the array's underlying memory is contiguous. */
1509 const array& a,
1510 bool allow_col_major = false,
1511 StreamOrDevice s = {});
1512
1514
1515} // namespace mlx::core
Definition array.h:24
array scatter_max(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and max updates to given linear indices.
array floor_divide(const array &a, const array &b, StreamOrDevice s={})
Compute integer division.
array radians(const array &a, StreamOrDevice s={})
Convert the elements of an array from Degrees to Radians.
array reshape(const array &a, Shape shape, StreamOrDevice s={})
Reshape an array to the given shape.
array arccos(const array &a, StreamOrDevice s={})
Arc Cosine of the elements of an array.
array scatter_min(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and min updates to given linear indices.
array less_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a <= b) element-wise.
array cumprod(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative product of an array.
array astype(array a, Dtype dtype, StreamOrDevice s={})
Convert an array to the given data type.
array rsqrt(const array &a, StreamOrDevice s={})
Reciprocal square root of the elements of an array.
array diag(const array &a, int k=0, StreamOrDevice s={})
Extract diagonal from a 2d array or create a diagonal matrix.
array square(const array &a, StreamOrDevice s={})
Square the elements of an array.
array ceil(const array &a, StreamOrDevice s={})
Ceil the elements of an array.
array log2(const array &a, StreamOrDevice s={})
Log base 2 of the elements of an array.
array clip(const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
Clip (limit) the values in an array.
array isnan(const array &a, StreamOrDevice s={})
array isneginf(const array &a, StreamOrDevice s={})
array subtract(const array &a, const array &b, StreamOrDevice s={})
Subtract two arrays.
array cummin(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative min of an array.
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with the given value(s).
array log10(const array &a, StreamOrDevice s={})
Log base 10 of the elements of an array.
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array sign(const array &a, StreamOrDevice s={})
The sign of the elements in an array.
array cosh(const array &a, StreamOrDevice s={})
Hyperbolic Cosine of the elements of an array.
array conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
General convolution with a filter.
array logical_or(const array &a, const array &b, StreamOrDevice s={})
Logical or of two arrays.
array moveaxis(const array &a, int source, int destination, StreamOrDevice s={})
Move an axis of an array.
array operator*(const array &a, const array &b)
array operator+(const array &a, const array &b)
array operator||(const array &a, const array &b)
array not_equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a != b) element-wise.
array erf(const array &a, StreamOrDevice s={})
Computes the error function of the elements of an array.
array sqrt(const array &a, StreamOrDevice s={})
Square root the elements of an array.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array add(const array &a, const array &b, StreamOrDevice s={})
Add two arrays.
array round(const array &a, int decimals, StreamOrDevice s={})
Round a floating point number.
array broadcast_to(const array &a, const Shape &shape, StreamOrDevice s={})
Broadcast an array to a given shape.
array conv1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D convolution with a filter
array bitwise_xor(const array &a, const array &b, StreamOrDevice s={})
Bitwise exclusive or.
array equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a == b) element-wise.
array view(const array &a, const Dtype &dtype, StreamOrDevice s={})
array gather_qmm(const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Compute matrix products with matrix-level gather.
array stop_gradient(const array &a, StreamOrDevice s={})
Stop the flow of gradients.
array scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and prod updates to given indices.
array cos(const array &a, StreamOrDevice s={})
Cosine of the elements of an array.
array operator>=(const array &a, const array &b)
Definition ops.h:355
array degrees(const array &a, StreamOrDevice s={})
Convert the elements of an array from Radians to Degrees.
array all(const array &a, bool keepdims, StreamOrDevice s={})
True if all elements in the array are true (or non-zero).
array tan(const array &a, StreamOrDevice s={})
Tangent of the elements of an array.
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) with ones on the main diagonal and zeros elsewhere.
array operator>>(const array &a, const array &b)
array minimum(const array &a, const array &b, StreamOrDevice s={})
Element-wise minimum between two arrays.
array prod(const array &a, bool keepdims, StreamOrDevice s={})
The product of all elements of the array.
array atleast_3d(const array &a, StreamOrDevice s={})
array operator<=(const array &a, const array &b)
Definition ops.h:383
array reciprocal(const array &a, StreamOrDevice s={})
The reciprocal (1/x) of the elements in an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array flatten(const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
Flatten the dimensions in the range [start_axis, end_axis].
array isclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
array operator|(const array &a, const array &b)
array topk(const array &a, int k, StreamOrDevice s={})
Returns topk elements of the flattened array.
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array abs(const array &a, StreamOrDevice s={})
Absolute value of elements in an array.
array conjugate(const array &a, StreamOrDevice s={})
std::vector< array > meshgrid(const std::vector< array > &arrays, bool sparse=false, const std::string &indexing="xy", StreamOrDevice s={})
A vector of coordinate arrays from coordinate vectors.
array tanh(const array &a, StreamOrDevice s={})
Hyperbolic Tangent of the elements of an array.
array as_strided(array a, Shape shape, Strides strides, size_t offset, StreamOrDevice s={})
Create a view of an array with the given shape and strides.
array inner(const array &a, const array &b, StreamOrDevice s={})
Compute the inner product of two vectors.
array unflatten(const array &a, int axis, Shape shape, StreamOrDevice s={})
Unflatten the axis to the given shape.
array block_masked_mm(array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
Compute matrix product with block masking.
array arctan2(const array &a, const array &b, StreamOrDevice s={})
Inverse tangent of the ratio of two arrays.
array number_of_elements(const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
Extract the number of elements along some axes as a scalar array.
array kron(const array &a, const array &b, StreamOrDevice s={})
Compute the Kronecker product of two arrays.
array conv3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D convolution with a filter
array log(const array &a, StreamOrDevice s={})
Natural logarithm of the elements of an array.
array sigmoid(const array &a, StreamOrDevice s={})
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
array squeeze(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Remove singleton dimensions at the given axes.
array greater_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a >= b) element-wise.
array expand_dims(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Add a singleton dimension at the given axes.
array isfinite(const array &a, StreamOrDevice s={})
array conv2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D convolution with a filter
array operator>(const array &a, const array &b)
Definition ops.h:341
array bitwise_and(const array &a, const array &b, StreamOrDevice s={})
Bitwise and.
std::vector< array > split(const array &a, int num_splits, int axis, StreamOrDevice s={})
Split an array into sub-arrays along a given axis.
array logical_and(const array &a, const array &b, StreamOrDevice s={})
Logical and of two arrays.
array erfinv(const array &a, StreamOrDevice s={})
Computes the inverse error function of the elements of an array.
array divide(const array &a, const array &b, StreamOrDevice s={})
Divide two arrays.
array power(const array &a, const array &b, StreamOrDevice s={})
Raise elements of a to the power of b element-wise.
array maximum(const array &a, const array &b, StreamOrDevice s={})
Element-wise maximum between two arrays.
array slice_update(const array &src, const array &update, Shape start, Shape stop, Shape strides, StreamOrDevice s={})
Update a slice from the source array.
array argmin(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the minimum value in the array.
array var(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the variance of the elements of an array.
array softmax(const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
Softmax of an array.
array sort(const array &a, StreamOrDevice s={})
Returns a sorted copy of the flattened array.
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array imag(const array &a, StreamOrDevice s={})
array addmm(array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
Compute D = beta * C + alpha * (A @ B)
array tril(array x, int k=0, StreamOrDevice s={})
array operator~(const array &a)
array any(const array &a, bool keepdims, StreamOrDevice s={})
True if any elements in the array are true (or non-zero).
array outer(const array &a, const array &b, StreamOrDevice s={})
Compute the outer product of two vectors.
array hadamard_transform(const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
Multiply the array by the Hadamard matrix of corresponding size.
array arcsin(const array &a, StreamOrDevice s={})
Arc Sine of the elements of an array.
array left_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the left.
array where(const array &condition, const array &x, const array &y, StreamOrDevice s={})
Select from x or y depending on condition.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
array contiguous(const array &a, bool allow_col_major=false, StreamOrDevice s={})
array bitwise_or(const array &a, const array &b, StreamOrDevice s={})
Bitwise inclusive or.
array gather_mm(array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
Compute matrix product with matrix-level gather.
array floor(const array &a, StreamOrDevice s={})
Floor the elements of an array.
array conv_transpose3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D transposed convolution with a filter
array argsort(const array &a, StreamOrDevice s={})
Returns indices that sort the flattened array.
array put_along_axis(const array &a, const array &indices, const array &values, int axis, StreamOrDevice s={})
Put the values into the array at the given indices along the axis.
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
True if two arrays have the same shape and elements.
array isinf(const array &a, StreamOrDevice s={})
array gather(const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const Shape &slice_sizes, StreamOrDevice s={})
Gather array entries given indices and slices.
array less(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a < b) element-wise.
array diagonal(const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
Extract a diagonal or construct a diagonal array.
array ones_like(const array &a, StreamOrDevice s={})
array negative(const array &a, StreamOrDevice s={})
Negate an array.
array linspace(double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
A 1D array of num evenly spaced numbers in the range [start, stop].
array remainder(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise remainder of division.
array arctan(const array &a, StreamOrDevice s={})
Arc Tangent of the elements of an array.
array conv_transpose1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D transposed convolution with a filter
std::vector< array > divmod(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise quotient and remainder.
array triu(array x, int k=0, StreamOrDevice s={})
array arccosh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Cosine of the elements of an array.
array tile(const array &arr, std::vector< int > reps, StreamOrDevice s={})
array nan_to_num(const array &a, float nan=0.0f, const std::optional< float > posinf=std::nullopt, const std::optional< float > neginf=std::nullopt, StreamOrDevice s={})
Replace NaN and infinities with finite numbers.
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator%(const array &a, const array &b)
array scatter_add_axis(const array &a, const array &indices, const array &values, int axis, StreamOrDevice s={})
Add the values into the array at the given indices along the axis.
std::tuple< array, array, array > quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
Quantize a matrix along its last axis.
array arctanh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Tangent of the elements of an array.
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s={})
Repeat an array along an axis.
std::vector< array > broadcast_arrays(const std::vector< array > &inputs, StreamOrDevice s={})
Broadcast a vector of arrays against one another.
array pad(const array &a, const std::vector< int > &axes, const Shape &low_pad_size, const Shape &high_pad_size, const array &pad_value=array(0), const std::string &mode="constant", StreamOrDevice s={})
Pad an array with a constant value.
array atleast_1d(const array &a, StreamOrDevice s={})
Convert an array to an array with at least ndim dimensions.
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s={})
Swap two axes of an array.
array logical_not(const array &a, StreamOrDevice s={})
Logical not of an array.
array trace(const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
Return the sum along a specified diagonal in the given array.
array quantized_matmul(array x, array w, array scales, array biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Quantized matmul multiplies x with a quantized matrix w.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array partition(const array &a, int kth, StreamOrDevice s={})
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
array take(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array slices at the given indices of the specified axis.
array operator^(const array &a, const array &b)
array roll(const array &a, int shift, StreamOrDevice s={})
Roll elements along an axis and introduce them on the other side.
std::vector< array > depends(const std::vector< array > &inputs, const std::vector< array > &dependencies)
Implements the identity function but allows injecting dependencies to other arrays.
array arcsinh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Sine of the elements of an array.
array scatter_add(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and add updates to given indices.
array logsumexp(const array &a, bool keepdims, StreamOrDevice s={})
The logsumexp of all elements of the array.
array scatter(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given indices.
array operator<<(const array &a, const array &b)
array isposinf(const array &a, StreamOrDevice s={})
array cumsum(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative sum of an array.
array operator-(const array &a)
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
array ones(const Shape &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with ones.
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array entries given indices along the axis.
array zeros(const Shape &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with zeros.
array argmax(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the maximum value in the array.
array conv_transpose2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D transposed convolution with a filter
array sin(const array &a, StreamOrDevice s={})
Sine of the elements of an array.
array operator&&(const array &a, const array &b)
array cummax(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative max of an array.
array operator<(const array &a, const array &b)
Definition ops.h:369
array atleast_2d(const array &a, StreamOrDevice s={})
array operator/(const array &a, const array &b)
array allclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
True if the two arrays are equal within the specified tolerance.
array operator&(const array &a, const array &b)
array bitwise_invert(const array &a, StreamOrDevice s={})
Invert the bits.
array argpartition(const array &a, int kth, StreamOrDevice s={})
Returns indices that partition the flattened array such that the smaller kth elements are first.
array greater(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a > b) element-wise.
array sinh(const array &a, StreamOrDevice s={})
Hyperbolic Sine of the elements of an array.
array multiply(const array &a, const array &b, StreamOrDevice s={})
Multiply two arrays.
array tensordot(const array &a, const array &b, const int axis=2, StreamOrDevice s={})
Returns a contraction of a and b over multiple dimensions.
array real(const array &a, StreamOrDevice s={})
array stack(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Stack arrays along a new axis.
array logaddexp(const array &a, const array &b, StreamOrDevice s={})
Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
array right_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the right.
array zeros_like(const array &a, StreamOrDevice s={})
Definition allocator.h:7
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
Stream to_stream(StreamOrDevice s)
void copy(const array &src, array &dst, CopyType ctype)
void slice(const array &in, array &out, const Shape &start_indices, const Shape &strides)
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
void concatenate(std::string &acc, T first)
Definition utils.h:62
std::vector< int64_t > Strides
Definition array.h:22
bool operator==(const Device &lhs, const Device &rhs)
bool operator!=(const Device &lhs, const Device &rhs)
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
void matmul(const array &a, const array &b, array &out, bool a_transposed, bool b_transposed, size_t lda, size_t ldb, float alpha, float beta)
Definition dtype.h:13