MLX
Loading...
Searching...
No Matches
ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/stream.h"
10#include "mlx/utils.h"
11
12namespace mlx::core {
13
23 double start,
24 double stop,
25 double step,
26 Dtype dtype,
27 StreamOrDevice s = {});
/**
 * Convenience overloads of arange taking double arguments. Overloads that
 * omit start, step, or dtype rely on defaults chosen by the implementation
 * (not visible in this listing).
 */
28array arange(double start, double stop, double step, StreamOrDevice s = {});
29array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
30array arange(double start, double stop, StreamOrDevice s = {});
31array arange(double stop, Dtype dtype, StreamOrDevice s = {});
32array arange(double stop, StreamOrDevice s = {});
33
/**
 * Integer-argument overloads of arange; omitted start/step are likewise
 * filled in by the implementation.
 */
34array arange(int start, int stop, int step, StreamOrDevice s = {});
35array arange(int start, int stop, StreamOrDevice s = {});
36array arange(int stop, StreamOrDevice s = {});
37
40 double start,
41 double stop,
42 int num = 50,
43 Dtype dtype = float32,
44 StreamOrDevice s = {});
45
48
51 array a,
52 std::vector<int> shape,
53 std::vector<size_t> strides,
54 size_t offset,
55 StreamOrDevice s = {});
56
59
62 std::vector<int> shape,
63 array vals,
64 Dtype dtype,
65 StreamOrDevice s = {});
/** full overload without an explicit dtype: fills `shape` from the array `vals`. */
66array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
/**
 * Scalar convenience overload: wraps `val` as `array(val, dtype)` and
 * forwards to the array-valued full on the resolved stream.
 */
67template <typename T>
68array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
69 return full(std::move(shape), array(val, dtype), to_stream(s));
70}
/**
 * Scalar convenience overload without a dtype: wraps `val` as `array(val)`,
 * so the element type is whatever the array constructor selects for T.
 */
71template <typename T>
72array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
73 return full(std::move(shape), array(val), to_stream(s));
74}
75
/** Fill an array of the given shape with zeros. */
77array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** zeros overload that defaults the dtype to float32. */
78inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
79 return zeros(shape, float32, s);
80}
82
/** Fill an array of the given shape with ones. */
84array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** ones overload that defaults the dtype to float32. */
85inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
86 return ones(shape, float32, s);
87}
89
/**
 * Fill an array of shape (n, m) with ones on diagonal k and zeros elsewhere.
 * The inline overloads below supply m = n, k = 0 and/or dtype = float32.
 */
92array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
/** Forwards eye(n, n, 0, dtype): square, main diagonal. */
93inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
94 return eye(n, n, 0, dtype, s);
95}
/** Forwards eye(n, m, 0, float32). */
96inline array eye(int n, int m, StreamOrDevice s = {}) {
97 return eye(n, m, 0, float32, s);
98}
/** Forwards eye(n, m, k, float32). */
99inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
100 return eye(n, m, k, float32, s);
101}
/** Forwards eye(n, n, 0, float32). */
102inline array eye(int n, StreamOrDevice s = {}) {
103 return eye(n, n, 0, float32, s);
104}
105
/** Square (n, n) matrix of zeros with ones on the main diagonal. */
108array identity(int n, Dtype dtype, StreamOrDevice s = {});
/** identity overload that defaults the dtype to float32. */
109inline array identity(int n, StreamOrDevice s = {}) {
110 return identity(n, float32, s);
111}
112
/**
 * tri with explicit n, m, diagonal offset k and dtype. The parameters appear
 * to mirror eye(n, m, k) — NOTE(review): confirm the fill semantics in the
 * implementation. The dtype parameter is named `type` here but `dtype` in
 * the other creation functions.
 */
113array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
/** Square overload: forwards tri(n, n, 0, type). */
114inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
115 return tri(n, n, 0, type, s);
116}
117
/** tril / triu take an input array and a diagonal offset k (default 0). */
118array tril(array x, int k = 0, StreamOrDevice s = {});
119array triu(array x, int k = 0, StreamOrDevice s = {});
120
/** Reshape `a` to `shape`. */
122array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
123
126 const array& a,
127 int start_axis,
128 int end_axis = -1,
129 StreamOrDevice s = {});
130
133
136 const array& a,
137 std::optional<float> scale = std::nullopt,
138 StreamOrDevice s = {});
139
142 const array& a,
143 const std::vector<int>& axes,
144 StreamOrDevice s = {});
145
/** Single-axis squeeze: forwards to the multi-axis overload with {axis}. */
147inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
148 return squeeze(a, std::vector<int>{axis}, s);
149}
150
153
156 const array& a,
157 const std::vector<int>& axes,
158 StreamOrDevice s = {});
159
/** Single-axis overload of expand_dims. */
161array expand_dims(const array& a, int axis, StreamOrDevice s = {});
162
165 const array& a,
166 std::vector<int> start,
167 std::vector<int> stop,
168 std::vector<int> strides,
169 StreamOrDevice s = {});
170
173 const array& a,
174 const std::vector<int>& start,
175 const std::vector<int>& stop,
176 StreamOrDevice s = {});
177
180 const array& src,
181 const array& update,
182 std::vector<int> start,
183 std::vector<int> stop,
184 std::vector<int> strides,
185 StreamOrDevice s = {});
186
189 const array& src,
190 const array& update,
191 std::vector<int> start,
192 std::vector<int> stop,
193 StreamOrDevice s = {});
194
/** Split `a` into `num_splits` pieces along `axis`. */
196std::vector<array>
197split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
/** As above, with the axis left to the implementation's default (not shown). */
198std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
/** Split `a` at the given `indices` along `axis`. */
199std::vector<array> split(
200 const array& a,
201 const std::vector<int>& indices,
202 int axis,
203 StreamOrDevice s = {});
/** As above, with the axis left to the implementation's default (not shown). */
204std::vector<array>
205split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
206
/**
 * Build coordinate grids from `arrays`. `sparse` defaults to false and
 * `indexing` to "xy"; the accepted indexing values are checked in the
 * implementation (not shown).
 * NOTE(review): std::string is used here, but this header does not include
 * <string> directly; it relies on a transitive include.
 */
208std::vector<array> meshgrid(
209 const std::vector<array>& arrays,
210 bool sparse = false,
211 std::string indexing = "xy",
212 StreamOrDevice s = {});
213
218 const array& a,
219 const std::optional<array>& a_min = std::nullopt,
220 const std::optional<array>& a_max = std::nullopt,
221 StreamOrDevice s = {});
222
225 const std::vector<array>& arrays,
226 int axis,
227 StreamOrDevice s = {});
/** Concatenate `arrays`; this overload omits the axis (implementation default). */
228array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
229
/** Stack `arrays` along `axis`; the second overload omits the axis. */
231array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
232array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
233
/** Repeat `arr` `repeats` times along `axis`; the second overload omits the axis. */
235array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
236array repeat(const array& arr, int repeats, StreamOrDevice s = {});
237
/** Tile `arr` according to the per-axis counts in `reps`. */
238array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
239
/** Permute the axes of `a` according to `axes`. */
241array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
243 const array& a,
244 std::initializer_list<int> axes,
245 StreamOrDevice s = {}) {
246 return transpose(a, std::vector<int>(axes), s);
247}
248
/** Swap axes `axis1` and `axis2` of `a`. */
250array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
251
254 const array& a,
255 int source,
256 int destination,
257 StreamOrDevice s = {});
258
261 const array& a,
262 const std::vector<int>& axes,
263 const std::vector<int>& low_pad_size,
264 const std::vector<int>& high_pad_size,
265 const array& pad_value = array(0),
266 StreamOrDevice s = {});
267
270 const array& a,
271 const std::vector<std::pair<int, int>>& pad_width,
272 const array& pad_value = array(0),
273 StreamOrDevice s = {});
275 const array& a,
276 const std::pair<int, int>& pad_width,
277 const array& pad_value = array(0),
278 StreamOrDevice s = {});
280 const array& a,
281 int pad_width,
282 const array& pad_value = array(0),
283 StreamOrDevice s = {});
284
287
290 const array& a,
291 const std::vector<int>& shape,
292 StreamOrDevice s = {});
293
/** Broadcast the arrays in `inputs` against each other. */
295std::vector<array> broadcast_arrays(
296 const std::vector<array>& inputs,
297 StreamOrDevice s = {});
298
/** Returns the bool array with (a == b) element-wise. */
300array equal(const array& a, const array& b, StreamOrDevice s = {});
/**
 * Operator forms take no stream argument, so equal() runs with its default
 * stream. Scalar operands are wrapped with array(...) before comparison.
 */
301inline array operator==(const array& a, const array& b) {
302 return equal(a, b);
303}
304template <typename T>
305array operator==(T a, const array& b) {
306 return equal(array(a), b);
307}
308template <typename T>
309array operator==(const array& a, T b) {
310 return equal(a, array(b));
311}
312
/** Returns the bool array with (a != b) element-wise. */
314array not_equal(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
315inline array operator!=(const array& a, const array& b) {
316 return not_equal(a, b);
317}
318template <typename T>
319array operator!=(T a, const array& b) {
320 return not_equal(array(a), b);
321}
322template <typename T>
323array operator!=(const array& a, T b) {
324 return not_equal(a, array(b));
325}
326
/** Returns the bool array with (a > b) element-wise. */
328array greater(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
329inline array operator>(const array& a, const array& b) {
330 return greater(a, b);
331}
332template <typename T>
333array operator>(T a, const array& b) {
334 return greater(array(a), b);
335}
336template <typename T>
337array operator>(const array& a, T b) {
338 return greater(a, array(b));
339}
340
/** Returns the bool array with (a >= b) element-wise. */
342array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
343inline array operator>=(const array& a, const array& b) {
344 return greater_equal(a, b);
345}
346template <typename T>
347array operator>=(T a, const array& b) {
348 return greater_equal(array(a), b);
349}
350template <typename T>
351array operator>=(const array& a, T b) {
352 return greater_equal(a, array(b));
353}
354
/** Returns the bool array with (a < b) element-wise. */
356array less(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
357inline array operator<(const array& a, const array& b) {
358 return less(a, b);
359}
360template <typename T>
361array operator<(T a, const array& b) {
362 return less(array(a), b);
363}
364template <typename T>
365array operator<(const array& a, T b) {
366 return less(a, array(b));
367}
368
/** Returns the bool array with (a <= b) element-wise. */
370array less_equal(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
371inline array operator<=(const array& a, const array& b) {
372 return less_equal(a, b);
373}
374template <typename T>
375array operator<=(T a, const array& b) {
376 return less_equal(array(a), b);
377}
378template <typename T>
379array operator<=(const array& a, T b) {
380 return less_equal(a, array(b));
381}
382
385 const array& a,
386 const array& b,
387 bool equal_nan,
388 StreamOrDevice s = {});
/** array_equal overload that forwards with equal_nan = false. */
389inline array
390array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
391 return array_equal(a, b, false, s);
392}
393
/** Test the elements of `a` for NaN. */
394array isnan(const array& a, StreamOrDevice s = {});
395
/** Test the elements of `a` for infinity. */
396array isinf(const array& a, StreamOrDevice s = {});
397
399
401
404 const array& condition,
405 const array& x,
406 const array& y,
407 StreamOrDevice s = {});
408
411 const array& a,
412 float nan = 0.0f,
413 const std::optional<float>& posinf = std::nullopt,
414 const std::optional<float>& neginf = std::nullopt,
415 StreamOrDevice s = {});
416
/** True if all elements in the array are true (or non-zero). */
418array all(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
419inline array all(const array& a, StreamOrDevice s = {}) {
420 return all(a, false, to_stream(s));
421}
422
425 const array& a,
426 const array& b,
427 double rtol = 1e-5,
428 double atol = 1e-8,
429 bool equal_nan = false,
430 StreamOrDevice s = {});
431
435 const array& a,
436 const array& b,
437 double rtol = 1e-5,
438 double atol = 1e-8,
439 bool equal_nan = false,
440 StreamOrDevice s = {});
441
447 const array& a,
448 const std::vector<int>& axes,
449 bool keepdims = false,
450 StreamOrDevice s = {});
451
457 const array& a,
458 int axis,
459 bool keepdims = false,
460 StreamOrDevice s = {});
461
/** True if any element in the array is true (or non-zero). */
463array any(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
464inline array any(const array& a, StreamOrDevice s = {}) {
465 return any(a, false, to_stream(s));
466}
467
473 const array& a,
474 const std::vector<int>& axes,
475 bool keepdims = false,
476 StreamOrDevice s = {});
477
483 const array& a,
484 int axis,
485 bool keepdims = false,
486 StreamOrDevice s = {});
487
/** Sum of all elements of `a`. */
489array sum(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
490inline array sum(const array& a, StreamOrDevice s = {}) {
491 return sum(a, false, to_stream(s));
492}
493
496 const array& a,
497 const std::vector<int>& axes,
498 bool keepdims = false,
499 StreamOrDevice s = {});
500
503 const array& a,
504 int axis,
505 bool keepdims = false,
506 StreamOrDevice s = {});
507
/** Mean of all elements of `a`. */
509array mean(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
510inline array mean(const array& a, StreamOrDevice s = {}) {
511 return mean(a, false, to_stream(s));
512}
513
516 const array& a,
517 const std::vector<int>& axes,
518 bool keepdims = false,
519 StreamOrDevice s = {});
520
523 const array& a,
524 int axis,
525 bool keepdims = false,
526 StreamOrDevice s = {});
527
/** Variance of all elements of `a`; `ddof` defaults to 0. */
529array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
/** Overload with keepdims = false and ddof = 0. */
530inline array var(const array& a, StreamOrDevice s = {}) {
531 return var(a, false, 0, to_stream(s));
532}
533
537 const array& a,
538 const std::vector<int>& axes,
539 bool keepdims = false,
540 int ddof = 0,
541 StreamOrDevice s = {});
542
546 const array& a,
547 int axis,
548 bool keepdims = false,
549 int ddof = 0,
550 StreamOrDevice s = {});
551
/** Computes the standard deviation of the elements of an array; `ddof` defaults to 0. */
553array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
/** Overload with keepdims = false and ddof = 0. */
554inline array std(const array& a, StreamOrDevice s = {}) {
555 return std(a, false, 0, to_stream(s));
556}
557
561 const array& a,
562 const std::vector<int>& axes,
563 bool keepdims = false,
564 int ddof = 0,
565 StreamOrDevice s = {});
566
570 const array& a,
571 int axis,
572 bool keepdims = false,
573 int ddof = 0,
574 StreamOrDevice s = {});
575
/** The product of all elements of the array. */
577array prod(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
578inline array prod(const array& a, StreamOrDevice s = {}) {
579 return prod(a, false, to_stream(s));
580}
581
584 const array& a,
585 const std::vector<int>& axes,
586 bool keepdims = false,
587 StreamOrDevice s = {});
588
591 const array& a,
592 int axis,
593 bool keepdims = false,
594 StreamOrDevice s = {});
595
/** Maximum over all elements of `a`. */
597array max(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
598inline array max(const array& a, StreamOrDevice s = {}) {
599 return max(a, false, to_stream(s));
600}
601
604 const array& a,
605 const std::vector<int>& axes,
606 bool keepdims = false,
607 StreamOrDevice s = {});
608
611 const array& a,
612 int axis,
613 bool keepdims = false,
614 StreamOrDevice s = {});
615
/** Minimum over all elements of `a`. */
617array min(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
618inline array min(const array& a, StreamOrDevice s = {}) {
619 return min(a, false, to_stream(s));
620}
621
624 const array& a,
625 const std::vector<int>& axes,
626 bool keepdims = false,
627 StreamOrDevice s = {});
628
631 const array& a,
632 int axis,
633 bool keepdims = false,
634 StreamOrDevice s = {});
635
/** argmin over all elements of `a`. */
637array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
/**
 * Overload with keepdims = false. NOTE(review): forwards `s` directly, while
 * the sibling reductions wrap it with to_stream(s); both pass a valid
 * StreamOrDevice, so this is a consistency nit only.
 */
638inline array argmin(const array& a, StreamOrDevice s = {}) {
639 return argmin(a, false, s);
640}
641
644 const array& a,
645 int axis,
646 bool keepdims = false,
647 StreamOrDevice s = {});
648
/** argmax over all elements of `a`. */
650array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
/**
 * Overload with keepdims = false. NOTE(review): forwards `s` directly, while
 * the sibling reductions wrap it with to_stream(s); consistency nit only.
 */
651inline array argmax(const array& a, StreamOrDevice s = {}) {
652 return argmax(a, false, s);
653}
654
657 const array& a,
658 int axis,
659 bool keepdims = false,
660 StreamOrDevice s = {});
661
/** Sort `a`; this overload has no axis argument. */
663array sort(const array& a, StreamOrDevice s = {});
664
/** Sort `a` along `axis`. */
666array sort(const array& a, int axis, StreamOrDevice s = {});
667
670
/** Indices that would sort `a` along `axis`. */
672array argsort(const array& a, int axis, StreamOrDevice s = {});
673
/** Partition `a` around its `kth` element; this overload has no axis argument. */
678array partition(const array& a, int kth, StreamOrDevice s = {});
679
/** Partition `a` around its `kth` element along `axis`. */
684array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
685
/** Indices that partition `a` around its `kth` element; no axis argument. */
690array argpartition(const array& a, int kth, StreamOrDevice s = {});
691
/** Indices that partition `a` around its `kth` element along `axis`. */
696array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
697
/** Top `k` elements of `a`; this overload has no axis argument. */
699array topk(const array& a, int k, StreamOrDevice s = {});
700
/** Top `k` elements of `a` along `axis`. */
702array topk(const array& a, int k, int axis, StreamOrDevice s = {});
703
/** log-sum-exp over all elements of `a`. */
705array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
/** Overload with keepdims = false. */
706inline array logsumexp(const array& a, StreamOrDevice s = {}) {
707 return logsumexp(a, false, to_stream(s));
708}
709
712 const array& a,
713 const std::vector<int>& axes,
714 bool keepdims = false,
715 StreamOrDevice s = {});
716
719 const array& a,
720 int axis,
721 bool keepdims = false,
722 StreamOrDevice s = {});
723
/** Absolute value of the elements of `a`. */
725array abs(const array& a, StreamOrDevice s = {});
726
730
/** The sign of the elements in an array. */
732array sign(const array& a, StreamOrDevice s = {});
733
736
/** Logical and of two arrays; operator&& is the stream-less operator form. */
738array logical_and(const array& a, const array& b, StreamOrDevice s = {});
739array operator&&(const array& a, const array& b);
740
/** Logical or of two arrays; operator|| is the stream-less operator form. */
742array logical_or(const array& a, const array& b, StreamOrDevice s = {});
743array operator||(const array& a, const array& b);
744
747
/** Add two arrays. */
749array add(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
750array operator+(const array& a, const array& b);
751template <typename T>
752array operator+(T a, const array& b) {
753 return add(array(a), b);
754}
755template <typename T>
756array operator+(const array& a, T b) {
757 return add(a, array(b));
758}
759
/** Subtract two arrays. */
761array subtract(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
762array operator-(const array& a, const array& b);
763template <typename T>
764array operator-(T a, const array& b) {
765 return subtract(array(a), b);
766}
767template <typename T>
768array operator-(const array& a, T b) {
769 return subtract(a, array(b));
770}
771
/** Multiply two arrays. */
773array multiply(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
774array operator*(const array& a, const array& b);
775template <typename T>
776array operator*(T a, const array& b) {
777 return multiply(array(a), b);
778}
779template <typename T>
780array operator*(const array& a, T b) {
781 return multiply(a, array(b));
782}
783
/** Divide two arrays. */
785array divide(const array& a, const array& b, StreamOrDevice s = {});
/**
 * Operator forms. Unlike +, -, * and %, the scalar overloads here are plain
 * `double` overloads rather than templates, so scalar operands are converted
 * to double at the call site.
 */
786array operator/(const array& a, const array& b);
787array operator/(double a, const array& b);
788array operator/(const array& a, double b);
789
/**
 * Returns a vector of arrays; presumably {quotient, remainder} —
 * NOTE(review): confirm the order in the implementation.
 */
791std::vector<array>
792divmod(const array& a, const array& b, StreamOrDevice s = {});
793
/** Compute integer division. */
795array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
796
/** Remainder of a divided by b. */
798array remainder(const array& a, const array& b, StreamOrDevice s = {});
/** Operator forms: default stream; scalar operands wrapped with array(...). */
799array operator%(const array& a, const array& b);
800template <typename T>
801array operator%(T a, const array& b) {
802 return remainder(array(a), b);
803}
804template <typename T>
805array operator%(const array& a, T b) {
806 return remainder(a, array(b));
807}
808
/** Element-wise maximum between two arrays. */
810array maximum(const array& a, const array& b, StreamOrDevice s = {});
811
/** Element-wise minimum between two arrays. */
813array minimum(const array& a, const array& b, StreamOrDevice s = {});
814
/** Floor of the elements of an array. */
816array floor(const array& a, StreamOrDevice s = {});
817
/** Ceiling of the elements of an array. */
819array ceil(const array& a, StreamOrDevice s = {});
820
823
/** Exponential of the elements of an array. */
825array exp(const array& a, StreamOrDevice s = {});
826
/** Sine of the elements of an array. */
828array sin(const array& a, StreamOrDevice s = {});
829
/** Cosine of the elements of an array. */
831array cos(const array& a, StreamOrDevice s = {});
832
/** Tangent of the elements of an array. */
834array tan(const array& a, StreamOrDevice s = {});
835
838
841
844
/** Two-argument arctangent of `a` and `b`, element-wise. */
846array arctan2(const array& a, const array& b, StreamOrDevice s = {});
847
/** Hyperbolic sine of the elements of an array. */
849array sinh(const array& a, StreamOrDevice s = {});
850
/** Hyperbolic cosine of the elements of an array. */
852array cosh(const array& a, StreamOrDevice s = {});
853
/** Hyperbolic tangent of the elements of an array. */
855array tanh(const array& a, StreamOrDevice s = {});
856
859
862
865
868
871
/** Natural logarithm of the elements of an array. */
873array log(const array& a, StreamOrDevice s = {});
874
/** Log base 2 of the elements of an array. */
876array log2(const array& a, StreamOrDevice s = {});
877
/** Log base 10 of the elements of an array. */
879array log10(const array& a, StreamOrDevice s = {});
880
/** Natural logarithm of one plus elements in the array: log(1 + a). */
882array log1p(const array& a, StreamOrDevice s = {});
883
/** Element-wise log(exp(a) + exp(b)). */
885array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
886
889
/** Computes the error function of the elements of an array. */
891array erf(const array& a, StreamOrDevice s = {});
892
895
/** Element-wise exp(a) - 1. */
897array expm1(const array& a, StreamOrDevice s = {});
898
901
/** Round the elements of `a` to `decimals` decimal places. */
903array round(const array& a, int decimals, StreamOrDevice s = {});
/** Overload with decimals = 0. */
904inline array round(const array& a, StreamOrDevice s = {}) {
905 return round(a, 0, s);
906}
907
/** Matrix product of `a` and `b`. */
909array matmul(const array& a, const array& b, StreamOrDevice s = {});
910
913 const array& a,
914 const std::vector<array>& indices,
915 const std::vector<int>& axes,
916 const std::vector<int>& slice_sizes,
917 StreamOrDevice s = {});
919 const array& a,
920 const array& indices,
921 int axis,
922 const std::vector<int>& slice_sizes,
923 StreamOrDevice s = {}) {
924 return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
925}
926
929 const array& a,
930 const array& indices,
931 int axis,
932 StreamOrDevice s = {});
933
/** Take elements of `a` at `indices`; this overload has no axis argument. */
935array take(const array& a, const array& indices, StreamOrDevice s = {});
936
939 const array& a,
940 const array& indices,
941 int axis,
942 StreamOrDevice s = {});
943
1043 const array& a,
1044 const std::vector<array>& indices,
1045 const array& updates,
1046 const std::vector<int>& axes,
1047 StreamOrDevice s = {});
1049 const array& a,
1050 const array& indices,
1051 const array& updates,
1052 int axis,
1053 StreamOrDevice s = {}) {
1054 return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
1055}
1056
1059 const array& a,
1060 const std::vector<array>& indices,
1061 const array& updates,
1062 const std::vector<int>& axes,
1063 StreamOrDevice s = {});
1065 const array& a,
1066 const array& indices,
1067 const array& updates,
1068 int axis,
1069 StreamOrDevice s = {}) {
1070 return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
1071}
1072
1075 const array& a,
1076 const std::vector<array>& indices,
1077 const array& updates,
1078 const std::vector<int>& axes,
1079 StreamOrDevice s = {});
1081 const array& a,
1082 const array& indices,
1083 const array& updates,
1084 int axis,
1085 StreamOrDevice s = {}) {
1086 return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
1087}
1088
1091 const array& a,
1092 const std::vector<array>& indices,
1093 const array& updates,
1094 const std::vector<int>& axes,
1095 StreamOrDevice s = {});
1097 const array& a,
1098 const array& indices,
1099 const array& updates,
1100 int axis,
1101 StreamOrDevice s = {}) {
1102 return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
1103}
1106 const array& a,
1107 const std::vector<array>& indices,
1108 const array& updates,
1109 const std::vector<int>& axes,
1110 StreamOrDevice s = {});
1112 const array& a,
1113 const array& indices,
1114 const array& updates,
1115 int axis,
1116 StreamOrDevice s = {}) {
1117 return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
1118}
1119
/** Square root of the elements of an array. */
1121array sqrt(const array& a, StreamOrDevice s = {});
1122
1125
1128 const array& a,
1129 const std::vector<int>& axes,
1130 bool precise = false,
1131 StreamOrDevice s = {});
1132
/** softmax overload without an axis argument; `precise` defaults to false. */
1134array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
1135
/** Single-axis softmax: forwards to the multi-axis overload with {axis}. */
1137inline array
1138softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
1139 return softmax(a, std::vector<int>{axis}, precise, s);
1140}
1141
/** Raise `a` to the power `b`, element-wise. */
1143array power(const array& a, const array& b, StreamOrDevice s = {});
1144
1147 const array& a,
1148 int axis,
1149 bool reverse = false,
1150 bool inclusive = true,
1151 StreamOrDevice s = {});
1152
1155 const array& a,
1156 int axis,
1157 bool reverse = false,
1158 bool inclusive = true,
1159 StreamOrDevice s = {});
1160
1163 const array& a,
1164 int axis,
1165 bool reverse = false,
1166 bool inclusive = true,
1167 StreamOrDevice s = {});
1168
1171 const array& a,
1172 int axis,
1173 bool reverse = false,
1174 bool inclusive = true,
1175 StreamOrDevice s = {});
1176
1179 array input,
1180 array weight,
1181 std::vector<int> stride = {},
1182 std::vector<int> padding_lo = {},
1183 std::vector<int> padding_hi = {},
1184 std::vector<int> kernel_dilation = {},
1185 std::vector<int> input_dilation = {},
1186 int groups = 1,
1187 bool flip = false,
1188 StreamOrDevice s = {});
1189
1192 const array& input,
1193 const array& weight,
1194 std::vector<int> stride = {},
1195 std::vector<int> padding = {},
1196 std::vector<int> kernel_dilation = {},
1197 std::vector<int> input_dilation = {},
1198 int groups = 1,
1199 bool flip = false,
1200 StreamOrDevice s = {}) {
1201 return conv_general(
1202 /* const array& input = */ input,
1203 /* const array& weight = */ weight,
1204 /* std::vector<int> stride = */ stride,
1205 /* std::vector<int> padding_lo = */ padding,
1206 /* std::vector<int> padding_hi = */ padding,
1207 /* std::vector<int> kernel_dilation = */ kernel_dilation,
1208 /* std::vector<int> input_dilation = */ input_dilation,
1209 /* int groups = */ groups,
1210 /* bool flip = */ flip,
1211 /* StreamOrDevice s = */ s);
1212}
1213
1216 const array& input,
1217 const array& weight,
1218 int stride = 1,
1219 int padding = 0,
1220 int dilation = 1,
1221 int groups = 1,
1222 StreamOrDevice s = {});
1223
1226 const array& input,
1227 const array& weight,
1228 const std::pair<int, int>& stride = {1, 1},
1229 const std::pair<int, int>& padding = {0, 0},
1230 const std::pair<int, int>& dilation = {1, 1},
1231 int groups = 1,
1232 StreamOrDevice s = {});
1233
1236 const array& input,
1237 const array& weight,
1238 const std::tuple<int, int, int>& stride = {1, 1, 1},
1239 const std::tuple<int, int, int>& padding = {0, 0, 0},
1240 const std::tuple<int, int, int>& dilation = {1, 1, 1},
1241 int groups = 1,
1242 StreamOrDevice s = {});
1243
1246 const array& x,
1247 const array& w,
1248 const array& scales,
1249 const array& biases,
1250 bool transpose = true,
1251 int group_size = 64,
1252 int bits = 4,
1253 StreamOrDevice s = {});
1254
/**
 * Quantize `w`, returning a tuple of three arrays (group_size defaults to 64,
 * bits to 4). Presumably (quantized weights, scales, biases) to match the
 * dequantize/quantized_matmul parameters — NOTE(review): confirm the order.
 * NOTE(review): std::tuple is used, but this header does not include <tuple>
 * directly.
 */
1256std::tuple<array, array, array> quantize(
1257 const array& w,
1258 int group_size = 64,
1259 int bits = 4,
1260 StreamOrDevice s = {});
1261
1264 const array& w,
1265 const array& scales,
1266 const array& biases,
1267 int group_size = 64,
1268 int bits = 4,
1269 StreamOrDevice s = {});
1270
1273 const array& x,
1274 const array& w,
1275 const array& scales,
1276 const array& biases,
1277 std::optional<array> lhs_indices = std::nullopt,
1278 std::optional<array> rhs_indices = std::nullopt,
1279 bool transpose = true,
1280 int group_size = 64,
1281 int bits = 4,
1282 StreamOrDevice s = {});
1283
1286 const array& a,
1287 const array& b,
1288 const int axis = 2,
1289 StreamOrDevice s = {});
1290
1292 const array& a,
1293 const array& b,
1294 const std::vector<int>& axes_a,
1295 const std::vector<int>& axes_b,
1296 StreamOrDevice s = {});
1297
1299array outer(const array& a, const array& b, StreamOrDevice s = {});
1300
1302array inner(const array& a, const array& b, StreamOrDevice s = {});
1303
1306 array c,
1307 array a,
1308 array b,
1309 const float& alpha = 1.f,
1310 const float& beta = 1.f,
1311 StreamOrDevice s = {});
1312
1315 array a,
1316 array b,
1317 int block_size,
1318 std::optional<array> mask_out = std::nullopt,
1319 std::optional<array> mask_lhs = std::nullopt,
1320 std::optional<array> mask_rhs = std::nullopt,
1321 StreamOrDevice s = {});
1322
1325 array a,
1326 array b,
1327 std::optional<array> lhs_indices = std::nullopt,
1328 std::optional<array> rhs_indices = std::nullopt,
1329 StreamOrDevice s = {});
1330
1333 const array& a,
1334 int offset = 0,
1335 int axis1 = 0,
1336 int axis2 = 1,
1337 StreamOrDevice s = {});
1338
/** Extract a diagonal from a 2d array or create a diagonal matrix; k defaults to 0. */
1340array diag(const array& a, int k = 0, StreamOrDevice s = {});
1341
1344 const array& a,
1345 int offset,
1346 int axis1,
1347 int axis2,
1348 Dtype dtype,
1349 StreamOrDevice s = {});
1351 const array& a,
1352 int offset,
1353 int axis1,
1354 int axis2,
1355 StreamOrDevice s = {});
1357
/**
 * Takes `inputs` and `dependencies` and returns a vector of arrays. Unlike
 * most ops it has no StreamOrDevice parameter. NOTE(review): the exact
 * dependency semantics are not visible in this listing.
 */
1363std::vector<array> depends(
1364 const std::vector<array>& inputs,
1365 const std::vector<array>& dependencies);
1366
/** Vector overloads of atleast_1d / atleast_2d / atleast_3d over the arrays in `a`. */
1369std::vector<array> atleast_1d(
1370 const std::vector<array>& a,
1371 StreamOrDevice s = {});
1373std::vector<array> atleast_2d(
1374 const std::vector<array>& a,
1375 StreamOrDevice s = {});
1377std::vector<array> atleast_3d(
1378 const std::vector<array>& a,
1379 StreamOrDevice s = {});
1380
1386 const array& a,
1387 std::vector<int> axes,
1388 bool inverted,
1389 Dtype dtype = int32,
1390 StreamOrDevice s = {});
1391
1393
/** Bitwise and; operator& is the stream-less operator form. */
1395array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
1396array operator&(const array& a, const array& b);
1397
/** Bitwise or; operator| is the stream-less operator form. */
1399array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
1400array operator|(const array& a, const array& b);
1401
/** Bitwise exclusive or; operator^ is the stream-less operator form. */
1403array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
1404array operator^(const array& a, const array& b);
1405
/** Shift `a` left by `b`; operator<< is the stream-less operator form. */
1407array left_shift(const array& a, const array& b, StreamOrDevice s = {});
1408array operator<<(const array& a, const array& b);
1409
/** Shift `a` right by `b`; operator>> is the stream-less operator form. */
1411array right_shift(const array& a, const array& b, StreamOrDevice s = {});
1412array operator>>(const array& a, const array& b);
1413
/**
 * View `a` with element type `dtype`. NOTE(review): whether this
 * reinterprets bits or converts values is not visible here; astype is the
 * documented conversion op.
 */
1414array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
1417} // namespace mlx::core
Definition array.h:20
array scatter_max(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given linear indices, combining them with existing values via max.
array floor_divide(const array &a, const array &b, StreamOrDevice s={})
Compute integer division.
array radians(const array &a, StreamOrDevice s={})
Convert the elements of an array from degrees to radians.
array arccos(const array &a, StreamOrDevice s={})
Arc Cosine of the elements of an array.
array scatter_min(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given linear indices, combining them with existing values via min.
array less_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a <= b) element-wise.
array cumprod(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative product of an array.
array astype(array a, Dtype dtype, StreamOrDevice s={})
Convert an array to the given data type.
array rsqrt(const array &a, StreamOrDevice s={})
Reciprocal of the square root of the elements of an array.
array diag(const array &a, int k=0, StreamOrDevice s={})
Extract diagonal from a 2d array or create a diagonal matrix.
array square(const array &a, StreamOrDevice s={})
Square the elements of an array.
array ceil(const array &a, StreamOrDevice s={})
Ceiling of the elements of an array.
array log2(const array &a, StreamOrDevice s={})
Log base 2 of the elements of an array.
array clip(const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
Clip (limit) the values in an array.
array isnan(const array &a, StreamOrDevice s={})
array isneginf(const array &a, StreamOrDevice s={})
array subtract(const array &a, const array &b, StreamOrDevice s={})
Subtract two arrays.
array cummin(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative min of an array.
array log10(const array &a, StreamOrDevice s={})
Log base 10 of the elements of an array.
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
array sign(const array &a, StreamOrDevice s={})
The sign of the elements in an array.
array cosh(const array &a, StreamOrDevice s={})
Hyperbolic Cosine of the elements of an array.
array conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
General convolution with a filter.
array logical_or(const array &a, const array &b, StreamOrDevice s={})
Logical or of two arrays.
array moveaxis(const array &a, int source, int destination, StreamOrDevice s={})
Move an axis of an array.
array operator*(const array &a, const array &b)
array operator+(const array &a, const array &b)
array operator||(const array &a, const array &b)
array not_equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a != b) element-wise.
array erf(const array &a, StreamOrDevice s={})
Computes the error function of the elements of an array.
array sqrt(const array &a, StreamOrDevice s={})
Square root the elements of an array.
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
array add(const array &a, const array &b, StreamOrDevice s={})
Add two arrays.
array round(const array &a, int decimals, StreamOrDevice s={})
Round the elements of an array to the given number of decimals.
array conv1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D convolution with a filter.
array bitwise_xor(const array &a, const array &b, StreamOrDevice s={})
Bitwise exclusive or.
array equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a == b) element-wise.
array zeros(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with zeros.
array view(const array &a, const Dtype &dtype, StreamOrDevice s={})
array gather_qmm(const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Compute matrix products with matrix-level gather.
array stop_gradient(const array &a, StreamOrDevice s={})
Stop the flow of gradients.
array scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and prod updates to given indices.
array slice_update(const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Update a slice from the source array.
array cos(const array &a, StreamOrDevice s={})
Cosine of the elements of an array.
array operator>=(const array &a, const array &b)
Definition ops.h:343
array degrees(const array &a, StreamOrDevice s={})
Convert the elements of an array from radians to degrees.
array all(const array &a, bool keepdims, StreamOrDevice s={})
True if all elements in the array are true (or non-zero).
array tan(const array &a, StreamOrDevice s={})
Tangent of the elements of an array.
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
array operator>>(const array &a, const array &b)
array minimum(const array &a, const array &b, StreamOrDevice s={})
Element-wise minimum between two arrays.
array prod(const array &a, bool keepdims, StreamOrDevice s={})
The product of all elements of the array.
array atleast_3d(const array &a, StreamOrDevice s={})
array operator<=(const array &a, const array &b)
Definition ops.h:371
array reciprocal(const array &a, StreamOrDevice s={})
The reciprocal (1/x) of the elements in an array.
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
array pad(const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), StreamOrDevice s={})
Pad an array with a constant value.
array nan_to_num(const array &a, float nan=0.0f, const std::optional< float > &posinf=std::nullopt, const std::optional< float > &neginf=std::nullopt, StreamOrDevice s={})
Replace NaN and infinities with finite numbers.
array flatten(const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
Flatten the dimensions in the range [start_axis, end_axis].
array isclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
array operator|(const array &a, const array &b)
array topk(const array &a, int k, StreamOrDevice s={})
Returns the top k elements of the flattened array.
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
array ones(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with ones.
array abs(const array &a, StreamOrDevice s={})
Absolute value of elements in an array.
std::vector< array > meshgrid(const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
A vector of coordinate arrays from coordinate vectors.
array conjugate(const array &a, StreamOrDevice s={})
array tanh(const array &a, StreamOrDevice s={})
Hyperbolic Tangent of the elements of an array.
array quantized_matmul(const array &x, const array &w, const array &scales, const array &biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Quantized matmul multiplies x with a quantized matrix w.
array inner(const array &a, const array &b, StreamOrDevice s={})
Compute the inner product of two vectors.
array block_masked_mm(array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
Compute matrix product with block masking.
array arctan2(const array &a, const array &b, StreamOrDevice s={})
Inverse tangent of the ratio of two arrays.
array number_of_elements(const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
Extract the number of elements along some axes as a scalar array.
array conv3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D convolution with a filter.
array log(const array &a, StreamOrDevice s={})
Natural logarithm of the elements of an array.
array sigmoid(const array &a, StreamOrDevice s={})
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
array squeeze(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Remove singleton dimensions at the given axes.
array greater_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a >= b) element-wise.
array expand_dims(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Add a singleton dimension at the given axes.
array conv2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D convolution with a filter.
array operator>(const array &a, const array &b)
Definition ops.h:329
array bitwise_and(const array &a, const array &b, StreamOrDevice s={})
Bitwise and.
std::vector< array > split(const array &a, int num_splits, int axis, StreamOrDevice s={})
Split an array into sub-arrays along a given axis.
array matmul(const array &a, const array &b, StreamOrDevice s={})
Matrix-matrix multiplication.
array logical_and(const array &a, const array &b, StreamOrDevice s={})
Logical and of two arrays.
array erfinv(const array &a, StreamOrDevice s={})
Computes the inverse error function of the elements of an array.
array divide(const array &a, const array &b, StreamOrDevice s={})
Divide two arrays.
array power(const array &a, const array &b, StreamOrDevice s={})
Raise elements of a to the power of b element-wise.
array maximum(const array &a, const array &b, StreamOrDevice s={})
Element-wise maximum between two arrays.
array reshape(const array &a, std::vector< int > shape, StreamOrDevice s={})
Reshape an array to the given shape.
array argmin(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the minimum value in the array.
array var(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the variance of the elements of an array.
array full(std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with the given value(s).
array softmax(const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
Softmax of an array.
array sort(const array &a, StreamOrDevice s={})
Returns a sorted copy of the flattened array.
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
array addmm(array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
Compute D = beta * C + alpha * (A @ B).
array tril(array x, int k=0, StreamOrDevice s={})
array any(const array &a, bool keepdims, StreamOrDevice s={})
True if any elements in the array are true (or non-zero).
array outer(const array &a, const array &b, StreamOrDevice s={})
Compute the outer product of two vectors.
array hadamard_transform(const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
Multiply the array by the Hadamard matrix of corresponding size.
array arcsin(const array &a, StreamOrDevice s={})
Arc Sine of the elements of an array.
array left_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the left.
array where(const array &condition, const array &x, const array &y, StreamOrDevice s={})
Select from x or y depending on condition.
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
array bitwise_or(const array &a, const array &b, StreamOrDevice s={})
Bitwise inclusive or.
array gather_mm(array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
Compute matrix product with matrix-level gather.
array floor(const array &a, StreamOrDevice s={})
Floor the elements of an array.
array as_strided(array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
Create a view of an array with the given shape and strides.
array argsort(const array &a, StreamOrDevice s={})
Returns indices that sort the flattened array.
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
True if two arrays have the same shape and elements.
array isinf(const array &a, StreamOrDevice s={})
array less(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a < b) element-wise.
array diagonal(const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
Extract a diagonal or construct a diagonal array.
array ones_like(const array &a, StreamOrDevice s={})
array negative(const array &a, StreamOrDevice s={})
Negate an array.
array linspace(double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
A 1D array of num evenly spaced numbers in the range [start, stop].
array remainder(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise remainder of division.
array arctan(const array &a, StreamOrDevice s={})
Arc Tangent of the elements of an array.
std::vector< array > divmod(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise quotient and remainder.
array triu(array x, int k=0, StreamOrDevice s={})
array arccosh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Cosine of the elements of an array.
array tile(const array &arr, std::vector< int > reps, StreamOrDevice s={})
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
array operator%(const array &a, const array &b)
std::tuple< array, array, array > quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
Quantize a matrix along its last axis.
array arctanh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Tangent of the elements of an array.
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s={})
Repeat an array along an axis.
array gather(const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
Gather array entries given indices and slices.
std::vector< array > broadcast_arrays(const std::vector< array > &inputs, StreamOrDevice s={})
Broadcast a vector of arrays against one another.
array atleast_1d(const array &a, StreamOrDevice s={})
Convert an array to an array with at least ndim dimensions.
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s={})
Swap two axes of an array.
array logical_not(const array &a, StreamOrDevice s={})
Logical not of an array.
array concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Concatenate arrays along a given axis.
array trace(const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
Return the sum along a specified diagonal in the given array.
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize().
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
array partition(const array &a, int kth, StreamOrDevice s={})
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
array take(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array slices at the given indices of the specified axis.
array operator^(const array &a, const array &b)
std::vector< array > depends(const std::vector< array > &inputs, const std::vector< array > &dependencies)
Implements the identity function but allows injecting dependencies to other arrays.
array arcsinh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Sine of the elements of an array.
array scatter_add(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and add updates to given indices.
array logsumexp(const array &a, bool keepdims, StreamOrDevice s={})
The logsumexp of all elements of the array.
array broadcast_to(const array &a, const std::vector< int > &shape, StreamOrDevice s={})
Broadcast an array to a given shape.
array scatter(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given indices.
array operator<<(const array &a, const array &b)
array slice(const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Slice an array.
array isposinf(const array &a, StreamOrDevice s={})
array cumsum(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative sum of an array.
array operator-(const array &a)
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array entries given indices along the axis.
array argmax(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the maximum value in the array.
array sin(const array &a, StreamOrDevice s={})
Sine of the elements of an array.
array operator&&(const array &a, const array &b)
array cummax(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative max of an array.
array operator<(const array &a, const array &b)
Definition ops.h:357
array atleast_2d(const array &a, StreamOrDevice s={})
array operator/(const array &a, const array &b)
array allclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
True if the two arrays are equal within the specified tolerance.
array operator&(const array &a, const array &b)
array argpartition(const array &a, int kth, StreamOrDevice s={})
Returns indices that partition the flattened array such that the smaller kth elements are first.
array greater(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a > b) element-wise.
array sinh(const array &a, StreamOrDevice s={})
Hyperbolic Sine of the elements of an array.
array multiply(const array &a, const array &b, StreamOrDevice s={})
Multiply two arrays.
array tensordot(const array &a, const array &b, const int axis=2, StreamOrDevice s={})
Returns a contraction of a and b over multiple dimensions.
array stack(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Stack arrays along a new axis.
array logaddexp(const array &a, const array &b, StreamOrDevice s={})
Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
array right_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the right.
array zeros_like(const array &a, StreamOrDevice s={})
Definition allocator.h:7
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
Stream to_stream(StreamOrDevice s)
void copy(const array &src, array &dst, CopyType ctype)
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
bool operator==(const Device &lhs, const Device &rhs)
bool operator!=(const Device &lhs, const Device &rhs)
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
Definition dtype.h:15